import torch
from plonk.models.pretrained_models import Plonk
from plonk.models.samplers.riemannian_flow_sampler import riemannian_flow_sampler
from plonk.models.postprocessing import CartesiantoGPS
from plonk.models.schedulers import (
    SigmoidScheduler,
    LinearScheduler,
    CosineScheduler,
)
from plonk.models.preconditioning import DDPMPrecond
from torchvision import transforms
from transformers import CLIPProcessor, CLIPVisionModel
from plonk.utils.image_processing import CenterCrop
import numpy as np
from plonk.utils.manifolds import Sphere
from torch.func import jacrev, vmap, vjp
from torchdiffeq import odeint
from tqdm import tqdm

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

MODELS = {
    "nicolas-dufour/PLONK_YFCC": {"emb_name": "dinov2"},
    "nicolas-dufour/PLONK_OSV_5M": {
        "emb_name": "street_clip",
    },
    "nicolas-dufour/PLONK_iNaturalist": {
        "emb_name": "dinov2",
    },
}


def scheduler_fn(
    scheduler_type: str, start: float, end: float, tau: float, clip_min: float = 1e-9
):
    if scheduler_type == "sigmoid":
        return SigmoidScheduler(start, end, tau, clip_min)
    elif scheduler_type == "cosine":
        return CosineScheduler(start, end, tau, clip_min)
    elif scheduler_type == "linear":
        return LinearScheduler(clip_min=clip_min)
    else:
        raise ValueError(f"Scheduler type {scheduler_type} not supported")
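

# Example usage (a sketch, mirroring the defaults PlonkPipeline uses below):
#   scheduler = scheduler_fn("sigmoid", start=-7, end=3, tau=1.0)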


class DinoV2FeatureExtractor:
    def __init__(self, device=device):
        super().__init__()
        self.device = device
        self.emb_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
        self.emb_model.eval()
        self.emb_model.to(self.device)
        self.augmentation = transforms.Compose(
            [
                CenterCrop(ratio="1:1"),
                transforms.Resize(
                    336, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
                ),
            ]
        )

    def __call__(self, batch):
        embs = []
        with torch.no_grad():
            for img in batch["img"]:
                emb = self.emb_model(
                    self.augmentation(img).unsqueeze(0).to(self.device)
                ).squeeze(0)
                embs.append(emb)
        batch["emb"] = torch.stack(embs)
        return batch


class StreetClipFeatureExtractor:
    def __init__(self, device=device):
        self.device = device
        self.emb_model = CLIPVisionModel.from_pretrained("geolocal/StreetCLIP").to(
            device
        )
        self.processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")

    def __call__(self, batch):
        inputs = self.processor(images=batch["img"], return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.emb_model(**inputs)
        embeddings = outputs.last_hidden_state[:, 0]
        batch["emb"] = embeddings
        return batch
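

# Both extractors share the same contract: they read PIL images from
# batch["img"] and attach a (B, D) embedding tensor under batch["emb"].
# A minimal sketch of that usage (assuming `pil_image` is an RGB PIL.Image):
#
#   extractor = DinoV2FeatureExtractor()
#   batch = extractor({"img": [pil_image]})
#   batch["emb"].shape  # -> torch.Size([1, D])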


def load_preprocessing(model_name):
    if MODELS[model_name]["emb_name"] == "dinov2":
        return DinoV2FeatureExtractor()
    elif MODELS[model_name]["emb_name"] == "street_clip":
        return StreetClipFeatureExtractor()
    else:
        raise ValueError(f"Embedding model {MODELS[model_name]['emb_name']} not found")


# Helper functions adapted from plonk/models/module.py
# for likelihood computation
def div_fn(u):
    """Accepts a function u: R^D -> R^D and returns (x, y) -> div_x u(x, y)."""
    J = jacrev(u, argnums=0)
    return lambda x, y: torch.trace(J(x, y).squeeze(0))


def output_and_div(vecfield, x, y, v=None):
    if v is None:
        # Exact divergence: trace of the Jacobian, vmapped over the batch
        dx = vecfield(x, y)
        div = vmap(div_fn(vecfield))(x, y)
    else:
        # Hutchinson estimator: v^T J v is an unbiased estimate of tr(J)
        # for Rademacher probes v
        vecfield_x = lambda x: vecfield(x, y)
        dx, vjpfunc = vjp(vecfield_x, x)
        vJ = vjpfunc(v)[0]
        div = torch.sum(vJ * v, dim=-1)
    return dx, div
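

# A minimal sanity check for the two divergence paths above (an illustrative
# sketch, not part of the original pipeline; the helper name is ours). For a
# linear field u(x, y) = A x + y, the exact divergence is trace(A) and the
# Rademacher estimate matches it in expectation.
def _demo_divergence_estimators():
    torch.manual_seed(0)
    A = torch.randn(3, 3)
    vecfield = lambda x, y: x @ A.T + y  # y enters as a constant offset
    x, y = torch.randn(4, 3), torch.randn(4, 3)
    _, div_exact = output_and_div(vecfield, x, y)  # exact trace, per sample
    v = torch.randint_like(x, 2) * 2 - 1  # Rademacher probe in {-1, +1}
    _, div_est = output_and_div(vecfield, x, y, v=v)  # single-probe estimate
    print(div_exact, div_est, torch.trace(A))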


def _gps_degrees_to_cartesian(gps_coords_deg, device):
    """Converts GPS coordinates (latitude, longitude) in degrees to Cartesian coordinates on the unit sphere."""
    if not isinstance(gps_coords_deg, np.ndarray):
        gps_coords_deg = np.array(gps_coords_deg)
    if gps_coords_deg.ndim == 1:
        gps_coords_deg = gps_coords_deg[np.newaxis, :]
    lat_rad = np.radians(gps_coords_deg[:, 0])
    lon_rad = np.radians(gps_coords_deg[:, 1])
    x = np.cos(lat_rad) * np.cos(lon_rad)
    y = np.cos(lat_rad) * np.sin(lon_rad)
    z = np.sin(lat_rad)
    cartesian_coords = np.stack([x, y, z], axis=-1)
    return torch.tensor(cartesian_coords, dtype=torch.float32, device=device)
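

# Quick sanity check for the conversion above (an illustrative sketch; the
# helper name is ours): the equator/prime-meridian point maps to (1, 0, 0)
# and the north pole maps to (0, 0, 1).
def _demo_gps_to_cartesian():
    pts = _gps_degrees_to_cartesian([[0.0, 0.0], [90.0, 0.0]], device="cpu")
    assert torch.allclose(pts[0], torch.tensor([1.0, 0.0, 0.0]), atol=1e-6)
    assert torch.allclose(pts[1], torch.tensor([0.0, 0.0, 1.0]), atol=1e-6)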


class PlonkPipeline:
    """
    The PlonkPipeline class performs geolocation prediction from images using a
    pre-trained PLONK model. It integrates feature extractors, samplers, and
    coordinate transformations to predict locations.

    Initialization:
        PlonkPipeline(
            model_path,
            scheduler="sigmoid",
            scheduler_start=-7,
            scheduler_end=3,
            scheduler_tau=1.0,
            device="cuda",
        )

    Parameters:
        model_path (str): Path to the pre-trained PLONK model.
        scheduler (str): The scheduler type to use. Options are "sigmoid", "cosine", "linear". Default is "sigmoid".
        scheduler_start (float): Start value for the scheduler. Default is -7.
        scheduler_end (float): End value for the scheduler. Default is 3.
        scheduler_tau (float): Tau value for the scheduler. Default is 1.0.
        device (str): Device to run the model on. Default is "cuda".

    Methods:
        model(*args, **kwargs):
            Runs the preconditioning on the network with the provided arguments.

        __call__(...):
            Predicts geolocation coordinates from input images.
            Parameters:
                images: Input images to predict locations for.
                batch_size (int, optional): Batch size for processing.
                x_N (torch.Tensor, optional): Initial noise tensor. If not provided, it is generated.
                num_steps (int, optional): Number of steps for the sampler.
                scheduler (callable, optional): Custom scheduler function. If not provided, the default scheduler is used.
                cfg (float): Classifier-free guidance scale. Default is 0.
                generator (torch.Generator, optional): Random number generator.
            Returns:
                np.ndarray: Predicted latitude and longitude coordinates in degrees.

        compute_likelihood(...):
            Computes the exact log-likelihood of observing the given coordinates for the given images.
            Parameters:
                images: Input images (PIL Image or list of PIL Images). Optional if emb is provided.
                coordinates: Target GPS coordinates (latitude, longitude) in degrees.
                emb: Pre-computed embeddings. If provided, images will be ignored.
                cfg (float): Classifier-free guidance scale. Default is 0 (no guidance).
                rademacher (bool): Whether to use the Rademacher estimator for the divergence. Default is False.
                atol (float): Absolute tolerance for ODE solver. Default is 1e-6.
                rtol (float): Relative tolerance for ODE solver. Default is 1e-6.
                normalize_logp (bool): Whether to normalize the log-likelihood by log(2) * dim. Default is True.

        compute_likelihood_grid(...):
            Computes the likelihood of an image over a global grid of coordinates.
            Parameters:
                image: Input PIL Image.
                grid_resolution_deg (float): The resolution of the grid in degrees. Default is 10 degrees.
                batch_size (int): How many grid points to process in each batch. Adjust based on available memory. Default is 1024.
                cfg (float): Classifier-free guidance scale passed to compute_likelihood. Default is 0.
            Returns:
                tuple: (latitude_grid, longitude_grid, likelihood_grid)
                    - latitude_grid (np.ndarray): 1D array of latitudes.
                    - longitude_grid (np.ndarray): 1D array of longitudes.
                    - likelihood_grid (np.ndarray): 2D array of log-likelihoods corresponding to the lat/lon grid.

        compute_localizability(...):
            Computes the localizability of an image, using importance sampling:
            samples are drawn from the model rather than from a uniform grid,
            which gives a more accurate estimate.
            Parameters:
                image: Input PIL Image.
                atol (float): Absolute tolerance for ODE solver. Default is 1e-6.
                rtol (float): Relative tolerance for ODE solver. Default is 1e-6.
                number_monte_carlo_samples (int): How many samples to use for importance sampling. Default is 1024.
            Returns:
                torch.Tensor: Localizability of the image.

    Example Usage:
        pipe = PlonkPipeline(
            "path/to/plonk/model",
        )
        pipe.to("cuda")
        coordinates = pipe(
            images,
            batch_size=32
        )
        likelihood = pipe.compute_likelihood(
            images,
            coordinates,
            cfg=0,
            rademacher=False,
        )
        localizability = pipe.compute_localizability(
            image,
            number_monte_carlo_samples=1024,
        )

    A runnable end-to-end sketch is also provided at the bottom of this module.
    """

    def __init__(
        self,
        model_path,
        scheduler="sigmoid",
        scheduler_start=-7,
        scheduler_end=3,
        scheduler_tau=1.0,
        device=device,
    ):
        self.network = Plonk.from_pretrained(model_path).to(device)
        self.network.requires_grad_(False).eval()
        assert scheduler in [
            "sigmoid",
            "cosine",
            "linear",
        ], f"Scheduler {scheduler} not supported"
        self.scheduler = scheduler_fn(
            scheduler, scheduler_start, scheduler_end, scheduler_tau
        )
        self.cond_preprocessing = load_preprocessing(model_name=model_path)
        self.postprocessing = CartesiantoGPS()
        self.sampler = riemannian_flow_sampler
        self.model_path = model_path
        self.preconditioning = DDPMPrecond()
        self.device = device
        # Add manifold
        self.manifold = Sphere()
        self.input_dim = 3  # 3D Cartesian coordinates on the unit sphere

    def model(self, *args, **kwargs):
        return self.preconditioning(self.network, *args, **kwargs)

    def __call__(
        self,
        images,
        batch_size=None,
        x_N=None,
        num_steps=None,
        scheduler=None,
        cfg=0,
        generator=None,
    ):
        """Sample from the model given conditioning.

        Args:
            images: Conditioning input (image or list of images)
            batch_size: Number of samples to generate (inferred from images if not provided)
            x_N: Initial noise tensor (generated if not provided)
            num_steps: Number of sampling steps (uses the sampler's default if not provided)
            scheduler: Custom scheduler function (uses default if not provided)
            cfg: Classifier-free guidance scale (default 0)
            generator: Random number generator

        Returns:
            Sampled GPS coordinates in degrees (np.ndarray) after postprocessing
        """
        # Set up batch size and initial noise
        shape = [3]
        if not isinstance(images, list):
            images = [images]
        if x_N is None:
            if batch_size is None:
                batch_size = len(images)
            x_N = torch.randn(
                batch_size, *shape, device=self.device, generator=generator
            )
        else:
            x_N = x_N.to(self.device)
            if x_N.ndim == 1:
                # Add a batch dimension if a single (3,) latent was provided
                x_N = x_N.unsqueeze(0)
            batch_size = x_N.shape[0]

        # Set up batch with conditioning
        batch = {"y": x_N}
        batch["img"] = images
        batch = self.cond_preprocessing(batch)
        if len(images) > 1:
            assert len(images) == batch_size
        else:
            batch["emb"] = batch["emb"].repeat(batch_size, 1)

        # Use default scheduler if not provided
        if scheduler is None:
            scheduler = self.scheduler

        # Sample from the model; only forward num_steps when explicitly given,
        # so the sampler's own default is used otherwise
        sampler_kwargs = {} if num_steps is None else {"num_steps": num_steps}
        output = self.sampler(
            self.model,
            batch,
            conditioning_keys="emb",
            scheduler=scheduler,
            cfg_rate=cfg,
            generator=generator,
            **sampler_kwargs,
        )

        # Apply postprocessing and return coordinates in degrees
        output = self.postprocessing(output)
        output = np.degrees(output.detach().cpu().numpy())
        return output

    def compute_likelihood(
        self,
        images=None,
        coordinates=None,
        emb=None,
        cfg=0,
        rademacher=False,
        atol=1e-6,
        rtol=1e-6,
        normalize_logp=True,
    ):
        """
        Computes the exact log-likelihood of observing the given coordinates for the given images.

        Args:
            images: Input images (PIL Image or list of PIL Images). Optional if emb is provided.
            coordinates: Target GPS coordinates (latitude, longitude) in degrees.
                Can be a list of pairs, numpy array (N, 2), or tensor (N, 2).
            emb: Pre-computed embeddings. If provided, images will be ignored.
            cfg (float): Classifier-free guidance scale. Default is 0 (no guidance).
            rademacher (bool): Whether to use the Rademacher estimator for the divergence. Default is False.
            atol (float): Absolute tolerance for ODE solver. Default is 1e-6.
            rtol (float): Relative tolerance for ODE solver. Default is 1e-6.
            normalize_logp (bool): Whether to normalize the log-likelihood by log(2) * dim. Default is True.

        Returns:
            torch.Tensor: Log-likelihood values for each input pair (image, coordinate).
        """
        nfe = [0]  # Counter for the number of function evaluations

        # 1. Get embeddings either from images or directly from the emb parameter
        if emb is not None:
            # Use provided embeddings directly
            if isinstance(emb, torch.Tensor):
                batch = {"emb": emb.to(self.device)}
            else:
                raise TypeError("emb must be a torch.Tensor")
        else:
            # Process images to get embeddings
            if not isinstance(images, list):
                images = [images]
            batch = {"img": images}
            batch = self.cond_preprocessing(batch)  # Adds the 'emb' key

        # 2. Preprocess coordinates (GPS degrees -> Cartesian)
        x_1 = _gps_degrees_to_cartesian(coordinates, self.device)
        if x_1.shape[0] != batch["emb"].shape[0]:
            if x_1.shape[0] == 1:
                # Repeat coordinates if only one is provided for multiple images
                x_1 = x_1.repeat(batch["emb"].shape[0], 1)
            elif batch["emb"].shape[0] == 1:
                # Repeat embedding if only one image is provided for multiple coords
                batch["emb"] = batch["emb"].repeat(x_1.shape[0], 1)
            else:
                raise ValueError(
                    f"Batch size mismatch between images ({batch['emb'].shape[0]}) and coordinates ({x_1.shape[0]})"
                )

        # Ensure correct shapes for the ODE solver
        if x_1.ndim == 1:
            x_1 = x_1.unsqueeze(0)
        if batch["emb"].ndim == 1:
            batch["emb"] = batch["emb"].unsqueeze(0)

        with torch.inference_mode(mode=False):  # Enable grads for Jacobian calculation
            # Define the ODE function: velocity field plus divergence accumulator
            def odefunc(t, tensor):
                nfe[0] += 1
                t = t.to(tensor)
                gamma = self.scheduler(t)
                x = tensor[..., : self.input_dim]
                y = batch["emb"]  # Conditioning

                def vecfield(x_vf, y_vf):
                    batch_vecfield = {
                        "y": x_vf,
                        "emb": y_vf,
                        "gamma": gamma.reshape(-1),
                    }
                    if cfg > 0:
                        # Apply classifier-free guidance
                        batch_vecfield_uncond = {
                            "y": x_vf,
                            "emb": torch.zeros_like(y_vf),  # Null condition
                            "gamma": gamma.reshape(-1),
                        }
                        model_output_cond = self.model(batch_vecfield)
                        model_output_uncond = self.model(batch_vecfield_uncond)
                        model_output = model_output_cond + cfg * (
                            model_output_cond - model_output_uncond
                        )
                    else:
                        # Unconditional or naturally conditioned score
                        model_output = self.model(batch_vecfield)
                    # Assuming a 'flow_matching' interpolant, matching the sampler used
                    d_gamma = self.scheduler.derivative(t).reshape(-1, 1)
                    return d_gamma * model_output

                if rademacher:
                    v = torch.randint_like(x, 2) * 2 - 1
                else:
                    v = None
                dx, div = output_and_div(vecfield, x, y, v=v)
                div = div.reshape(-1, 1)
                del t, x
                return torch.cat([dx, div], dim=-1)

            # 3. Solve the ODE
            state1 = torch.cat([x_1, torch.zeros_like(x_1[..., :1])], dim=-1)
            # Note: using standard odeint here. For strict Riemannian integration,
            # a manifold-aware solver might be needed, but this follows the
            # structure from DiffGeolocalizer.compute_exact_loglikelihood more closely.
            with torch.no_grad():
                state0 = odeint(
                    odefunc,
                    state1,
                    t=torch.linspace(0, 1.0, 2).to(x_1.device),
                    atol=atol,
                    rtol=rtol,
                    method="dopri5",
                    options={"min_step": 1e-5},
                )[-1]  # Get the state at t=0

            x_0, logdetjac = state0[..., : self.input_dim], state0[..., -1]
            # Project the final point onto the manifold (optional but good practice)
            x_0 = self.manifold.projx(x_0)

            # 4. Compute log probability
            # Log prob of base distribution (Gaussian projected onto sphere approx)
            logp0 = self.manifold.base_logprob(x_0)
            # Change of variables formula: log p(x_1) = log p(x_0) + log |det J|
            logp1 = logp0 + logdetjac

        # Optional: normalize by log(2) * dim for bits per dimension
        if normalize_logp:
            logp1 = logp1 / (self.input_dim * np.log(2))
        print(f"Likelihood NFE: {nfe[0]}")  # Number of function evaluations
        return logp1

    def compute_likelihood_grid(
        self,
        image,
        grid_resolution_deg=10,
        batch_size=1024,
        cfg=0,
    ):
        """
        Computes the likelihood of an image over a global grid of coordinates.

        Args:
            image: Input PIL Image.
            grid_resolution_deg (float): The resolution of the grid in degrees.
                Default is 10 degrees.
            batch_size (int): How many grid points to process in each batch.
                Adjust based on available memory. Default is 1024.
            cfg (float): Classifier-free guidance scale passed to compute_likelihood.
                Default is 0.

        Returns:
            tuple: (latitude_grid, longitude_grid, likelihood_grid)
                - latitude_grid (np.ndarray): 1D array of latitudes.
                - longitude_grid (np.ndarray): 1D array of longitudes.
                - likelihood_grid (np.ndarray): 2D array of log-likelihoods
                  corresponding to the lat/lon grid.
        """
        # 1. Generate the grid
        latitudes = np.arange(-90, 90 + grid_resolution_deg, grid_resolution_deg)
        longitudes = np.arange(-180, 180 + grid_resolution_deg, grid_resolution_deg)
        lon_grid, lat_grid = np.meshgrid(longitudes, latitudes)

        # Flatten the grid for processing
        all_coordinates = np.vstack([lat_grid.ravel(), lon_grid.ravel()]).T
        num_points = all_coordinates.shape[0]
        print(
            f"Computing likelihood over a {latitudes.size}x{longitudes.size} grid ({num_points} points)..."
        )

        # Compute the image embedding once and reuse it for every batch
        emb = self.cond_preprocessing({"img": [image]})["emb"]

        # 2. Process in batches
        all_likelihoods = []
        for i in tqdm(
            range(0, num_points, batch_size), desc="Computing Likelihood Grid"
        ):
            coord_batch = all_coordinates[i : i + batch_size]
            # Compute likelihood for the batch
            likelihood_batch = self.compute_likelihood(
                emb=emb,
                coordinates=coord_batch,
                cfg=cfg,
                rademacher=False,  # Exact divergence is better for the grid
            )
            all_likelihoods.append(likelihood_batch.detach().cpu().numpy())

        # 3. Combine and reshape results
        likelihood_flat = np.concatenate(all_likelihoods, axis=0)
        likelihood_grid = likelihood_flat.reshape(lat_grid.shape)

        # Return grid definition and likelihood values
        return latitudes, longitudes, likelihood_grid

    def compute_localizability(
        self,
        image,
        atol=1e-6,
        rtol=1e-6,
        number_monte_carlo_samples=1024,
    ):
        """
        Computes the localizability of an image. We use importance sampling,
        drawing samples from the model rather than from a uniform grid, to
        get a more accurate estimate.

        Args:
            image: Input PIL Image.
            atol (float): Absolute tolerance for ODE solver. Default is 1e-6.
            rtol (float): Relative tolerance for ODE solver. Default is 1e-6.
            number_monte_carlo_samples (int): Number of samples for importance sampling. Default is 1024.

        Returns:
            torch.Tensor: Localizability estimate (scalar).
        """
        samples = self(image, batch_size=number_monte_carlo_samples)
        emb = self.cond_preprocessing({"img": [image]})["emb"]
        localizability = self.compute_likelihood(
            emb=emb,
            coordinates=samples,
            atol=atol,
            rtol=rtol,
            normalize_logp=False,
        )  # Importance sampling of the likelihood
        return localizability.mean() / (4 * torch.pi * np.log(2))

    def to(self, device):
        self.network.to(device)
        self.postprocessing.to(device)
        self.device = torch.device(device)
        return self
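

# A minimal end-to-end sketch (illustrative, not part of the original
# pipeline). It assumes a local test image at "example.jpg" (hypothetical
# path) and enough memory to load one of the checkpoints listed in MODELS.
if __name__ == "__main__":
    from PIL import Image

    pipe = PlonkPipeline("nicolas-dufour/PLONK_YFCC", device=device)
    image = Image.open("example.jpg").convert("RGB")  # hypothetical path

    # 1. Sample candidate locations (lat/lon in degrees).
    coords = pipe(image, batch_size=32)
    print("sampled coordinates:", coords[:5])

    # 2. Exact log-likelihood of a specific location for this image.
    logp = pipe.compute_likelihood(images=image, coordinates=[[48.8566, 2.3522]])
    print("log-likelihood (Paris):", logp.item())

    # 3. Coarse global likelihood map and localizability score.
    lats, lons, grid = pipe.compute_likelihood_grid(image, grid_resolution_deg=30)
    print("likelihood grid shape:", grid.shape)
    loc = pipe.compute_localizability(image, number_monte_carlo_samples=128)
    print("localizability:", loc.item())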