Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import os | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image | |
| import xml.etree.ElementTree as ET | |
| import torch.optim as optim | |
| from torch import nn | |
| # Your model training and evaluation functions (already defined in your previous code) | |
| # Define the custom dataset | |
class FaceMaskDataset(Dataset):
    """Face-mask detection dataset built from an image folder and an XML annotation folder.

    For every file in ``images_dir`` the annotation is looked up in
    ``annotations_dir`` under the same file stem with an ``.xml`` suffix
    (``img_01.jpg`` -> ``img_01.xml``). Each ``<object>`` in the annotation
    contributes one ``[xmin, ymin, xmax, ymax]`` box and one label
    (``1`` for ``"mask"``, ``0`` for anything else).

    Items are ``(image, target)`` with ``target = {'boxes': float32 tensor
    (N, 4), 'labels': int64 tensor (N,)}``. Images are resized to ``resize``
    and the boxes are scaled by the same factors so they stay aligned with the
    resized image. Items whose annotation is missing or contains no objects are
    returned as ``(None, None)``; the DataLoader's collate function must drop them.

    Args:
        images_dir: Directory containing the images.
        annotations_dir: Directory containing the XML annotation files.
        transform: Optional callable applied to the resized PIL image. Boxes are
            NOT adjusted for it, so it should not change image geometry
            (``transforms.ToTensor()`` is fine; ``Resize``/crops are not).
        resize: ``(width, height)`` passed to ``PIL.Image.resize``.
    """

    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        # Sorted: os.listdir order is arbitrary, and a stable order keeps
        # dataset indices reproducible across runs and filesystems.
        self.image_files = sorted(os.listdir(images_dir))

    def __len__(self):
        return len(self.image_files)

    def _annotation_path(self, image_file):
        """Return the annotation path for ``image_file``: same stem, ``.xml`` suffix.

        splitext swaps only the real extension (whatever it is, e.g. ``.jpeg``)
        instead of every ``.jpg``/``.png`` substring anywhere in the name.
        """
        stem, _ext = os.path.splitext(image_file)
        return os.path.join(self.annotations_dir, stem + ".xml")

    def __getitem__(self, idx):
        """Return ``(image, target)`` for item ``idx``, or ``(None, None)`` if unusable."""
        image_file = self.image_files[idx]
        annotation_path = self._annotation_path(image_file)
        if not os.path.exists(annotation_path):
            print(f"Warning: Annotation file {annotation_path} does not exist. Skipping image {image_file}.")
            return None, None  # Return None if annotation is missing

        image_path = os.path.join(self.images_dir, image_file)
        # `with` closes the file handle; convert() returns a loaded copy that
        # stays usable after the file is closed.
        with Image.open(image_path) as src:
            orig_width, orig_height = src.size
            image = src.convert("RGB")
        image = image.resize(self.resize)

        # Boxes in the XML are in original-image pixels; scale them by the same
        # factors as the image so they remain aligned after resizing.
        new_width, new_height = self.resize
        boxes, labels = self.load_annotations(
            annotation_path,
            scale_x=new_width / orig_width,
            scale_y=new_height / orig_height,
        )
        if boxes is None or labels is None:
            return None, None  # Skip if annotations are invalid

        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target

    def load_annotations(self, annotation_path, scale_x=1.0, scale_y=1.0):
        """Parse a Pascal-VOC-style XML file into box and label tensors.

        Args:
            annotation_path: Path to the XML file.
            scale_x: Factor applied to ``xmin``/``xmax`` (default: no scaling).
            scale_y: Factor applied to ``ymin``/``ymax`` (default: no scaling).

        Returns:
            ``(boxes, labels)`` as float32 ``(N, 4)`` and int64 ``(N,)`` tensors,
            or ``(None, None)`` when the file contains no ``<object>`` entries.
        """
        # NOTE(review): annotations may come from user uploads; xml.etree is not
        # hardened against maliciously crafted XML (see the Python `xml` module
        # security notes) — consider defusedxml if that matters here.
        root = ET.parse(annotation_path).getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text) * scale_x
            ymin = float(bndbox.find('ymin').text) * scale_y
            xmax = float(bndbox.find('xmax').text) * scale_x
            ymax = float(bndbox.find('ymax').text) * scale_y
            boxes.append([xmin, ymin, xmax, ymax])
            # A missing <name> counts as "not mask"; surrounding whitespace is ignored.
            name = (obj.findtext('name') or '').strip()
            # "mask" = 1, anything else (e.g. "no_mask") = 0.
            # NOTE(review): torchvision detection models reserve label 0 for
            # background; if get_model() returns one, "no_mask" boxes would be
            # treated as background — confirm the intended label scheme.
            labels.append(1 if name == "mask" else 0)
        if not boxes:
            return None, None  # No objects annotated
        return (
            torch.as_tensor(boxes, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.int64),
        )
| # Model Training Loop (referred to from previous code) | |
def _model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    param = next(model.parameters(), None)
    return param.device if param is not None else torch.device("cpu")


def _valid_batch(images, targets, device):
    """Drop samples marked invalid (``None``) and move the rest to ``device``.

    Handles both a whole batch collated to ``(None, None)`` and individual
    ``None`` samples inside a batch. Returns ``(images, targets)`` lists, both
    empty when nothing usable remains.
    """
    if images is None or targets is None:
        return [], []
    pairs = [
        (img, tgt)
        for img, tgt in zip(images, targets)
        if img is not None and tgt is not None
    ]
    moved_images = [img.to(device) for img, _ in pairs]
    moved_targets = [{k: v.to(device) for k, v in tgt.items()} for _, tgt in pairs]
    return moved_images, moved_targets


# Model Training Loop
def train_model(model, train_loader, val_loader, optimizer, num_epochs=10, device=None):
    """Train ``model`` for ``num_epochs`` epochs, printing train and validation loss.

    ``model(images, targets)`` is expected to return a dict of loss tensors in
    training mode; their sum is back-propagated. Invalid (``None``) samples are
    skipped, and a batch with no valid samples is skipped entirely. The printed
    training loss is averaged over the batches actually used.

    Args:
        model: Module returning a loss dict when called as ``model(images, targets)``.
        train_loader: Iterable of ``(images, targets)`` batches for training.
        val_loader: Iterable of ``(images, targets)`` batches for validation.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device to move batches to; defaults to the model's parameter device.
    """
    if device is None:
        device = _model_device(model)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for images, targets in train_loader:
            images, targets = _valid_batch(images, targets, device)
            if not images:
                continue  # Skip batches with no valid images/annotations
            optimizer.zero_grad()
            loss_dict = model(images, targets)
            total_loss = sum(loss_dict.values())
            total_loss.backward()
            optimizer.step()
            running_loss += total_loss.item()
            num_batches += 1
        # Average over batches actually trained on; 0.0 avoids dividing by zero.
        avg_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}")
        # Evaluate after every epoch
        val_loss = evaluate_model(model, val_loader, device=device)
        print(f"Validation Loss: {val_loss}")


# Validation function
def evaluate_model(model, val_loader, device=None):
    """Return the mean summed loss of ``model`` over ``val_loader`` (0.0 if no valid batch).

    torchvision detection models return losses only in training mode (in eval
    mode they return a list of predictions), so the losses are computed in
    train mode with gradients disabled. The model is left in eval mode on exit.

    Args:
        model: Module returning a loss dict when called as ``model(images, targets)``.
        val_loader: Iterable of ``(images, targets)`` batches.
        device: Device to move batches to; defaults to the model's parameter device.
    """
    if device is None:
        device = _model_device(model)
    running_loss = 0.0
    num_batches = 0
    # NOTE(review): train mode also enables dropout and updates non-frozen
    # BatchNorm running statistics — confirm that is acceptable for the model
    # get_model() returns.
    model.train()
    try:
        with torch.no_grad():
            for images, targets in val_loader:
                images, targets = _valid_batch(images, targets, device)
                if not images:
                    continue  # Skip batches with no valid data
                loss_dict = model(images, targets)
                running_loss += sum(loss_dict.values()).item()
                num_batches += 1
    finally:
        model.eval()
    return running_loss / num_batches if num_batches else 0.0
| # Function to upload dataset and start training | |
def _upload_path(upload):
    """Return the filesystem path of a Gradio file upload.

    Depending on the Gradio version, ``gr.File`` may hand the callback either a
    path string or a tempfile-like object whose ``.name`` is the path; both are
    accepted.
    """
    return getattr(upload, "name", upload)


def _detection_collate(batch):
    """Collate ``(image, target)`` samples into ``(images, targets)`` tuples.

    Detection targets hold a variable number of boxes per image, so samples are
    not stacked. Samples marked invalid (``(None, None)``) are dropped; if none
    remain, ``(None, None)`` is returned so the training loop skips the batch.
    """
    valid = [s for s in batch if s[0] is not None and s[1] is not None]
    if not valid:
        return None, None
    images, targets = zip(*valid)
    return images, targets


# Function to upload dataset and start training
def train_on_uploaded_data(train_data, val_data, model_path="model.pth"):
    """Train a face-mask detector on uploaded ZIP datasets and save its weights.

    Each archive must contain ``images/`` and ``annotations/`` directories at
    its root (the layout FaceMaskDataset reads). Archives are extracted into a
    temporary directory that is removed once training ends.

    Args:
        train_data: Gradio upload of the training ZIP (path or file object).
        val_data: Gradio upload of the validation ZIP (path or file object).
        model_path: Where the trained ``state_dict`` is written.

    Returns:
        A status message for the Gradio textbox.
    """
    # Local imports keep this function self-contained.
    import tempfile
    import zipfile

    # Gradio passes None for an input the user has not uploaded yet.
    if train_data is None or val_data is None:
        return "Please upload both a training and a validation dataset (ZIP)."

    with tempfile.TemporaryDirectory() as workdir:
        split_dirs = {}
        for split, upload in (("train", train_data), ("val", val_data)):
            split_dir = os.path.join(workdir, split)
            # zipfile instead of `os.system("unzip ...")`: no shell, no external
            # binary, and a corrupt upload is reported instead of ignored.
            try:
                with zipfile.ZipFile(_upload_path(upload)) as archive:
                    archive.extractall(split_dir)
            except zipfile.BadZipFile:
                return f"The {split} upload is not a valid ZIP archive."
            split_dirs[split] = split_dir

        # FaceMaskDataset resizes each image itself (and, for the target size
        # given here, the old extra Resize((224, 224)) transform is no longer
        # needed); a geometric transform here would move pixels but not boxes.
        to_tensor = transforms.Compose([transforms.ToTensor()])

        def make_loader(split_dir, shuffle):
            dataset = FaceMaskDataset(
                images_dir=os.path.join(split_dir, "images"),
                annotations_dir=os.path.join(split_dir, "annotations"),
                transform=to_tensor,
                resize=(224, 224),  # same final image size as the previous pipeline
            )
            return DataLoader(dataset, batch_size=32, shuffle=shuffle, collate_fn=_detection_collate)

        train_loader = make_loader(split_dirs["train"], shuffle=True)
        val_loader = make_loader(split_dirs["val"], shuffle=False)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): get_model is not defined in this file; presumably it
        # comes from earlier code — confirm it is in scope.
        model = get_model(num_classes=2)
        model.to(device)
        optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
        # Must run inside the `with`: the datasets read files lazily from workdir.
        train_model(model, train_loader, val_loader, optimizer, num_epochs=10)

    torch.save(model.state_dict(), model_path)
    return "Training completed and model saved."
| # Create Gradio Interface | |
# Build the Gradio UI. `live` is intentionally not set: with live=True Gradio
# reruns the function whenever any input changes, which would start a full
# training run as soon as the first file is uploaded (and again for the second).
iface = gr.Interface(
    fn=train_on_uploaded_data,
    inputs=[
        gr.File(label="Upload Train Dataset (ZIP)", file_types=[".zip"]),
        gr.File(label="Upload Validation Dataset (ZIP)", file_types=[".zip"]),
    ],
    outputs=gr.Textbox(label="Training Status"),
)

# Launch Gradio interface only when run as a script, so importing this module
# does not start a server.
if __name__ == "__main__":
    iface.launch()