Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader, random_split | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
# -------- CONFIG --------
# The defaults are the original author's local (Windows) paths. Override them
# with AGE_DATA_DIR / AGE_CHECKPOINT_PATH when deploying elsewhere, e.g. on a
# Linux host where "D:/..." does not exist and the checkpoint would never load.
data_dir = os.environ.get("AGE_DATA_DIR", "D:/Dataset/face_age")
checkpoint_path = os.environ.get(
    "AGE_CHECKPOINT_PATH", "D:/Dataset/age_prediction_model2.pth"
)
# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# -------- SIMPLE CNN MODEL --------
class SimpleCNN(nn.Module):
    """Small three-stage convolutional regressor producing one value (age).

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2), so spatial
    size halves three times. The classifier's Linear(128 * 16 * 16, 256)
    therefore fixes the input at (N, 3, 128, 128); the output is (N, 1).

    Layer order (and therefore the ``state_dict`` key names such as
    ``features.0.weight`` / ``classifier.1.weight``) must stay stable so saved
    checkpoints keep loading.
    """

    def __init__(self):
        super().__init__()
        stages = []
        # 128x128 -> 64x64 -> 32x32 -> 16x16 after the three pooling layers.
        for in_channels, out_channels in ((3, 32), (32, 64), (64, 128)):
            stages += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 16 * 16, 256),
            nn.ReLU(),
            nn.Linear(256, 1),  # single regression output: age
        )

    def forward(self, x):
        """Map a batch of images to a (N, 1) tensor of predicted ages."""
        return self.classifier(self.features(x))
# -------- LOAD MODEL --------
model = SimpleCNN().to(device)
# Check if checkpoint exists before loading
if os.path.exists(checkpoint_path):
    # map_location remaps tensors saved on another device (e.g. a CUDA
    # checkpoint) onto `device`, so a CPU-only host can still load it.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"Model loaded from {checkpoint_path}")
else:
    print(f"Error: Checkpoint file not found at {checkpoint_path}. Please check the path.")
# Inference only: switch to evaluation mode whether or not weights were loaded.
model.eval()
# -------- PREPROCESSING --------
# Resize to the 128x128 input that SimpleCNN's classifier is sized for, convert
# to a float tensor, then shift each of the three channels by 0.5 and scale by
# 1 / 0.5 (so values in [0, 1] land in [-1, 1]).
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
# -------- PREDICTION FUNCTION --------
def predict_age(image):
    """Predict the age of the person in ``image``.

    Args:
        image: The value Gradio passes for the image input. This may be a PIL
            image, a NumPy array (Gradio's default image type), or None when
            no image is present (e.g. after the input is cleared with
            ``live=True``).

    Returns:
        ``"Predicted Age: <value>"`` with two decimals, or a prompt asking for
        an image when ``image`` is None.
    """
    if image is None:
        return "Please upload an image."
    if not isinstance(image, Image.Image):
        # transforms.Resize accepts PIL images / tensors, not NumPy arrays.
        image = Image.fromarray(image)
    # Drop alpha / expand grayscale so the tensor has the 3 channels the
    # first Conv2d and the 3-value Normalize expect.
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        output = model(batch)
    age = output.item()  # (1, 1) tensor -> Python float
    return f"Predicted Age: {age:.2f}"
# -------- GRADIO INTERFACE --------
# gr.inputs.* (and Image's `shape` / `source` arguments) were deprecated in
# Gradio 3 and removed in Gradio 4; gr.Image(type="pil", image_mode="RGB") is
# accepted by both. Resizing to 128x128 is done by `transform`, so no
# component-side shape is needed (note: that squashes non-square images rather
# than cropping them -- presumably matching training; confirm if it matters).
iface = gr.Interface(
    fn=predict_age,
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs="text",
    title="Age Prediction Model",
    description="Upload an image to predict the age.",
    live=True,
)
iface.launch()