import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import time
import numpy as np
import math
# import wandb
from datasets import load_dataset
import os
# Define the model
class DinoRegressionHeteroImages(nn.Module):
    def __init__(self, dino_model, hidden_dim=128, dropout=0.1, dino_dim=1024):
        super().__init__()
        self.dino = dino_model  # ViT backbone (pre-trained DINOv2)
        for p in self.dino.parameters():
            p.requires_grad = False  # freeze the backbone; only the head trains
        # KEEP THE SAME LAYER NAMES AS THE EMBEDDING-ONLY MODEL
        self.embedding_to_hidden = nn.Linear(dino_dim, hidden_dim)
        self.leaky_relu = nn.LeakyReLU()
        self.dropout = nn.Dropout(dropout)
        self.hidden_to_hidden = nn.Linear(hidden_dim, hidden_dim)
        self.out_mu = nn.Linear(hidden_dim, 1)
        self.out_logvar = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        h = self.dino(x)  # [B, dino_dim]
        h = self.embedding_to_hidden(h)
        h = self.leaky_relu(h)
        h = self.dropout(h)
        h = self.hidden_to_hidden(h)
        h = self.leaky_relu(h)
        mu = self.out_mu(h).squeeze(1)
        logvar = self.out_logvar(h).squeeze(1)
        logvar = torch.clamp(logvar, -10.0, 3.0)  # σ = exp(logvar / 2) ∈ ~[0.007, 4.5]
        return mu, logvar
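
# A quick shape check for the head (illustrative sketch, not part of training;
# assumes the torch.hub download below works):
#   backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
#   mu, logvar = DinoRegressionHeteroImages(backbone)(torch.randn(2, 3, 224, 224))
#   print(mu.shape, logvar.shape)  # torch.Size([2]) torch.Size([2])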
# Standard image transform
imgtransform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda x: x.convert('RGB')),  # ensure images are in RGB format
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
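# The mean/std values above are the standard ImageNet statistics; DINOv2's own
# evaluation transforms normalize the same way.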
# This uses the Hugging Face datasets library to load the dataset.
class LifespanDataset(Dataset):
    def __init__(self, split="train",
                 repo_id="TristanKE/RemainingLifespanPredictionFaces",
                 transform=None):
        self.ds = load_dataset(repo_id, split=split)
        self.transform = transform
        # z-score statistics for normalizing the regression target
        remaining = np.array(self.ds["remaining_lifespan"], dtype=np.float32)
        self.lifespan_mean = float(remaining.mean())
        self.lifespan_std = float(remaining.std())

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        ex = self.ds[idx]  # dict with keys: image, remaining_lifespan, …
        img = ex["image"]  # PIL.Image
        if self.transform:
            img = self.transform(img)
        target = (ex["remaining_lifespan"] - self.lifespan_mean) / self.lifespan_std
        return img, torch.tensor(target, dtype=torch.float32)
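
# Targets are z-scored; to report a prediction in years, invert the
# normalization. A hypothetical helper (not in the original script):
def target_to_years(pred_norm, dataset):
    # map a normalized prediction back to remaining-lifespan years
    return pred_norm * dataset.lifespan_std + dataset.lifespan_mean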
# Gaussian negative log-likelihood loss with per-sample (heteroscedastic) variance
def heteroscedastic_nll(y, mu, logvar):
    inv_var = torch.exp(-logvar)
    return (0.5 * inv_var * (y - mu) ** 2 + 0.5 * logvar).mean()
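
# Up to the dropped constant 0.5 * log(2π), this is the standard Gaussian NLL
# 0.5 * (logvar + (y - mu)² / var); torch.nn.GaussianNLLLoss should agree
# (modulo its eps clamping) when called as, for example:
#   nn.GaussianNLLLoss()(mu, y, torch.exp(logvar))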
# Cosine learning-rate schedule: multiplier decays from 1 toward 0 over training
def cosine_schedule(epoch, total_epochs):
    return 0.5 * (1 + math.cos(math.pi * epoch / total_epochs))
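
# With total_epochs=10 this gives a multiplier of 1.0 at epoch 0, 0.5 at
# epoch 5, and ~0.024 at epoch 9; the LR never quite reaches zero because the
# last epoch index is total_epochs - 1.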
# Main training script
if __name__ == "__main__":
    # Configuration; most things can be changed here, including the dataset
    cfg = {
        "N_HEADONLY_EPOCHS": 0,  # not used in this script
        "N_EPOCHS": 10,
        "BASE_LR": 1e-4,
        "BS": 32,
        "HIDDEN": 128,
        "DROPOUT": 0.01,
        # "WANDB": True,
        "REPO_ID": "TristanKE/RemainingLifespanPredictionFaces",
        # "REPO_ID": "TristanKE/RemainingLifespanPredictionWholeImgs",
        "DINO_MODEL": "dinov2_vitl14_reg",
        # "DINO_MODEL": "dinov2_vitg14_reg",  # the largest model, but also the slowest
        "DINO_DIM": 1024,
        # "DINO_DIM": 1536,  # for the larger model
    }
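
    # DINO_DIM must match the backbone's embedding width: 1024 for
    # dinov2_vitl14_reg, 1536 for dinov2_vitg14_reg. A cheap guard, added here
    # as a sketch (not part of the original script):
    _expected_dims = {"dinov2_vitl14_reg": 1024, "dinov2_vitg14_reg": 1536}
    assert _expected_dims.get(cfg["DINO_MODEL"], cfg["DINO_DIM"]) == cfg["DINO_DIM"], \
        "DINO_DIM must match the chosen DINO_MODEL"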
    # if cfg["WANDB"]:
    #     wandb.init(project="mortpred", config=cfg)
    torch.manual_seed(1)
    ds = LifespanDataset(repo_id=cfg["REPO_ID"], transform=imgtransform)
    # 80/20 train/test split; random_split already returns Subset objects,
    # so the splits can be passed to the DataLoaders directly
    test_sz = int(0.2 * len(ds))
    train_sz = len(ds) - test_sz
    train_ds, test_ds = random_split(ds, [train_sz, test_sz])
    train_loader = DataLoader(train_ds, batch_size=cfg["BS"], shuffle=True, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=cfg["BS"], shuffle=False, num_workers=4)
    # Load the backbone and move the model to the GPU if one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dino_backbone = torch.hub.load("facebookresearch/dinov2", cfg["DINO_MODEL"]).to(device)
    model = DinoRegressionHeteroImages(dino_backbone, hidden_dim=cfg["HIDDEN"],
                                       dropout=cfg["DROPOUT"], dino_dim=cfg["DINO_DIM"]).to(device)
    # The frozen backbone parameters never receive gradients, so Adam leaves them untouched
    optimizer = optim.Adam(model.parameters(), lr=cfg["BASE_LR"])
    scheduler = LambdaLR(optimizer, lambda e: cosine_schedule(e, cfg["N_EPOCHS"]))
    best_test_mae = float("inf")
    for epoch in range(cfg["N_EPOCHS"]):
        # Train
        model.train()
        tr_nll, tr_mae = 0.0, 0.0
        t0 = time.time()
        for imgs, tgt in train_loader:
            imgs, tgt = imgs.to(device), tgt.to(device)
            optimizer.zero_grad()
            mu, logvar = model(imgs)
            loss = heteroscedastic_nll(tgt, mu, logvar)
            loss.backward()
            optimizer.step()
            tr_nll += loss.item() * imgs.size(0)
            tr_mae += torch.abs(mu.detach() - tgt).sum().item()
            # if cfg["WANDB"]:
            #     wandb.log({
            #         "train_nll": loss.item(),
            #         "train_mae": torch.abs(mu.detach() - tgt).mean().item() * ds.lifespan_std,
            #         "train_std": torch.exp(0.5 * logvar).mean().item() * ds.lifespan_std,
            #     })
        tr_nll /= train_sz
        tr_mae = tr_mae / train_sz * ds.lifespan_std  # convert MAE back to years
        # Evaluate
        model.eval()
        te_nll, te_mae = 0.0, 0.0
        with torch.no_grad():
            for imgs, tgt in test_loader:
                imgs, tgt = imgs.to(device), tgt.to(device)
                mu, logvar = model(imgs)
                nll = heteroscedastic_nll(tgt, mu, logvar)
                te_nll += nll.item() * imgs.size(0)
                te_mae += torch.abs(mu - tgt).sum().item()
        te_nll /= test_sz
        te_mae = te_mae / test_sz * ds.lifespan_std  # convert MAE back to years
        print(f"Epoch {epoch+1}/{cfg['N_EPOCHS']} | {time.time()-t0:.1f}s | NLL tr {tr_nll:.3f} / te {te_nll:.3f} | MAE(te) {te_mae:.2f} yrs")
        # if cfg["WANDB"]:
        #     wandb.log({
        #         "train_nll": tr_nll,
        #         "test_nll": te_nll,
        #         "test_mae_yrs": te_mae,
        #         "lr": scheduler.get_last_lr()[0],
        #     })
        scheduler.step()
        # Save the best model by test MAE
        if te_mae < best_test_mae:
            best_test_mae = te_mae
            os.makedirs("savedmodels", exist_ok=True)
            torch.save(model.state_dict(), f"savedmodels/dino_finetuned_faces_l1_{cfg['DINO_DIM']}_best.pth")
            print(f"\tNew best model saved (test MAE {te_mae:.3f})")
    # if cfg["WANDB"]:
    #     wandb.finish()
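
    # A minimal inference sketch (illustrative; assumes training above has run
    # and `img` is a PIL image — the names here are hypothetical, not original code):
    #   model.load_state_dict(torch.load(f"savedmodels/dino_finetuned_faces_l1_{cfg['DINO_DIM']}_best.pth"))
    #   model.eval()
    #   with torch.no_grad():
    #       mu, logvar = model(imgtransform(img).unsqueeze(0).to(device))
    #   years = mu.item() * ds.lifespan_std + ds.lifespan_mean
    #   sigma_years = math.exp(0.5 * logvar.item()) * ds.lifespan_std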