Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import base64 | |
| import io | |
| import os | |
| import numpy as np | |
| from pathlib import Path | |
| from plonk.pipe import PlonkPipeline | |
| import random | |
# Cached PLONK pipeline; populated lazily by load_plonk_model().
model = None
# Production runs use real PLONK predictions. Set the environment variable
# PLONK_MOCK_MODE to 1/true/yes/on to return mock coordinates instead
# (useful for UI/API testing without loading the model).
MOCK_MODE = os.environ.get("PLONK_MOCK_MODE", "").strip().lower() in {"1", "true", "yes", "on"}
def load_plonk_model():
    """
    Return the shared PLONK_YFCC pipeline, constructing it on first use.

    The pipeline is stored in the module-level ``model`` global, so every
    call after the first returns the same instance without reloading.
    """
    global model
    if model is not None:
        return model
    print("Loading PLONK_YFCC model...")
    model = PlonkPipeline(model_path="nicolas-dufour/PLONK_YFCC")
    print("Model loaded successfully!")
    return model
# Anchor points for mock predictions: (latitude, longitude) of major cities.
_MOCK_CITY_COORDS = (
    (40.7128, -74.0060),    # New York
    (34.0522, -118.2437),   # Los Angeles
    (51.5074, -0.1278),     # London
    (48.8566, 2.3522),      # Paris
    (35.6762, 139.6503),    # Tokyo
    (37.7749, -122.4194),   # San Francisco
    (41.8781, -87.6298),    # Chicago
    (25.7617, -80.1918),    # Miami
    (45.5017, -73.5673),    # Montreal
    (52.5200, 13.4050),     # Berlin
    (-33.8688, 151.2093),   # Sydney
    (19.4326, -99.1332),    # Mexico City
)


def mock_plonk_prediction():
    """
    Return a fake ``(lat, lon)`` pair for MOCK_MODE testing.

    One anchor city is chosen at random. Each coordinate is then shifted by an
    independent uniform offset in [-2, 2] degrees, so repeated calls differ.
    """
    anchor_lat, anchor_lon = random.choice(_MOCK_CITY_COORDS)
    lat_offset = random.uniform(-2, 2)
    lon_offset = random.uniform(-2, 2)
    return anchor_lat + lat_offset, anchor_lon + lon_offset
def real_plonk_prediction(image, pipeline=None, *, batch_size=32, cfg=2.0, num_steps=32):
    """
    Run PLONK on ``image`` and summarize the sampled GPS coordinates.

    Draws ``batch_size`` samples (32 by default) so their spread can serve as
    an uncertainty estimate.

    Args:
        image: RGB PIL image forwarded unchanged to the pipeline.
        pipeline: Optional callable with the PlonkPipeline call signature.
            Defaults to the shared cached model from ``load_plonk_model()``.
        batch_size: Number of samples requested from the pipeline.
        cfg: Guidance scale passed to the pipeline as ``cfg``.
        num_steps: Sampling steps passed to the pipeline as ``num_steps``.

    Returns:
        ``(mean_lat, mean_lon, uncertainty_km, num_samples)``. The coordinates
        are in degrees and ``uncertainty_km`` is a rough spread radius in km.

    Raises:
        ValueError: If the pipeline returned no samples.
    """
    if pipeline is None:
        # Reuse the single cached model rather than keeping a second copy
        # attached to the gradio module.
        pipeline = load_plonk_model()
    raw = pipeline(image, batch_size=batch_size, cfg=cfg, num_steps=num_steps)
    return _summarize_gps_samples(raw)


def _summarize_gps_samples(raw):
    """Reduce (lat, lon) degree samples to a mean location and a km spread."""
    # The pipeline may return a torch tensor; bring it to host numpy first.
    if hasattr(raw, "detach"):
        raw = raw.detach().cpu().numpy()
    # Column 0 is treated as latitude and column 1 as longitude, in degrees,
    # as the original code assumed. NOTE(review): confirm against
    # PlonkPipeline's output format.
    samples = np.asarray(raw, dtype=float).reshape(-1, 2)
    if samples.shape[0] == 0:
        raise ValueError("PLONK returned no GPS samples")
    lats = samples[:, 0]
    lons = samples[:, 1]

    mean_lat = float(np.mean(lats))
    # Circular mean: an arithmetic mean of 179 and -179 would give 0 instead
    # of 180, so average unit vectors instead.
    lon_rad = np.radians(lons)
    mean_lon = float(np.degrees(np.arctan2(np.mean(np.sin(lon_rad)), np.mean(np.cos(lon_rad)))))

    # Signed longitude offsets from the mean, wrapped into [-180, 180).
    dlon = (lons - mean_lon + 180.0) % 360.0 - 180.0
    km_per_degree = 111.32
    lat_spread_km = float(np.std(lats)) * km_per_degree
    # A degree of longitude shrinks by cos(latitude).
    lon_spread_km = float(np.std(dlon)) * km_per_degree * float(np.cos(np.radians(mean_lat)))
    uncertainty_km = float(np.hypot(lat_spread_km, lon_spread_km))
    return mean_lat, mean_lon, uncertainty_km, int(samples.shape[0])
def predict_location(image):
    """
    Gradio handler: predict where ``image`` was taken, formatted as Markdown.

    Args:
        image: PIL image from the upload widget, or None if nothing was uploaded.

    Returns:
        A Markdown string with the predicted coordinates, or a human-readable
        error message. Errors are reported to the user instead of raised; the
        full traceback is also printed to the server log so failures can be
        diagnosed.
    """
    import traceback

    if image is None:
        return "Please upload an image."
    try:
        # Normalize non-RGB uploads (e.g. RGBA, L, P) to 3-channel RGB.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if MOCK_MODE:
            lat, lon = mock_plonk_prediction()
            confidence = "mock"
            uncertainty_km = None
            num_samples = 1
            note = " (Mock prediction for testing)"
        else:
            lat, lon, uncertainty_km, num_samples = real_plonk_prediction(image)
            confidence = "high"
            note = f" (Real PLONK prediction, {num_samples} samples)"

        if uncertainty_km is not None:
            uncertainty_text = f"\n**Uncertainty:** ±{uncertainty_km:.1f} km"
        else:
            uncertainty_text = ""
        mode_label = '🧪 Mock Testing' if MOCK_MODE else '🚀 Production'
        lines = [
            f"🗺️ **Predicted Location**{note}",
            f"**Latitude:** {lat:.6f}",
            f"**Longitude:** {lon:.6f}{uncertainty_text}",
            f"**Confidence:** {confidence}",
            f"**Samples:** {num_samples}",
            f"**Mode:** {mode_label}",
            "🌍 *This prediction estimates where the image was taken based on visual content.*",
        ]
        return "\n".join(lines) + "\n"
    except Exception as e:
        # Keep the user-facing message, but don't lose the stack trace.
        traceback.print_exc()
        return f"❌ Error processing image: {str(e)}"
def predict_location_json(image):
    """
    JSON API handler: predict where ``image`` was taken, as structured data.

    Args:
        image: PIL image, or None if no image was supplied.

    Returns:
        On success, a dict with ``status="success"``, ``mode``,
        ``predicted_location`` (``latitude``/``longitude`` rounded to 6
        places), ``confidence``, ``samples``, ``note`` and, when available,
        ``uncertainty_km`` (rounded to 1 place). On failure, a dict with
        ``status="error"`` and an ``error`` message. The traceback is also
        printed to the server log.
    """
    import traceback

    if image is None:
        return {
            "error": "No image provided",
            "status": "error"
        }
    try:
        # Normalize non-RGB uploads (e.g. RGBA, L, P) to 3-channel RGB.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if MOCK_MODE:
            lat, lon = mock_plonk_prediction()
            confidence = "mock"
            uncertainty_km = None
            num_samples = 1
        else:
            lat, lon, uncertainty_km, num_samples = real_plonk_prediction(image)
            confidence = "high"

        # float()/int() so numpy scalars become plain JSON numbers.
        result = {
            "status": "success",
            "mode": "mock" if MOCK_MODE else "production",
            "predicted_location": {
                "latitude": round(float(lat), 6),
                "longitude": round(float(lon), 6)
            },
            "confidence": confidence,
            "samples": int(num_samples),
            "note": "This is a mock prediction for testing" if MOCK_MODE else f"Real PLONK prediction using {num_samples} samples"
        }
        if uncertainty_km is not None:
            result["uncertainty_km"] = round(float(uncertainty_km), 1)
        return result
    except Exception as e:
        # Keep the structured error response, but don't lose the stack trace.
        traceback.print_exc()
        return {
            "error": str(e),
            "status": "error"
        }
# Create the Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="🗺️ PLONK: Around the World in 80 Timesteps"
) as demo:
    # Header. This must be an f-string so the current mode is interpolated;
    # a plain string would render the `{...}` expression literally.
    gr.Markdown(f"""
# 🗺️ PLONK: Around the World in 80 Timesteps
A generative approach to global visual geolocation. Upload an image and PLONK will predict where it was taken!
This uses the PLONK model concept from the paper: *"Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation"*
**Current Mode:** {'🧪 Mock Testing' if MOCK_MODE else '🚀 Production'} - {'Simulated coordinates for API/UI testing (no model inference).' if MOCK_MODE else 'Real PLONK model predictions with 32 samples for uncertainty estimation.'}
**Configuration:** Guidance Scale = 2.0, Samples = 32, Steps = 32
""")

    with gr.Tab("🖼️ Image Upload"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Upload an image",
                    type="pil",
                    sources=["upload", "webcam", "clipboard"]
                )
                predict_btn = gr.Button(
                    "🔍 Predict Location",
                    variant="primary",
                    size="lg"
                )
                clear_btn = gr.ClearButton(
                    components=[image_input],
                    value="🗑️ Clear"
                )
            with gr.Column(scale=1):
                output_text = gr.Markdown(
                    label="Prediction Result",
                    value="Upload an image and click 'Predict Location' to see results."
                )

    with gr.Tab("📡 API Information"):
        gr.Markdown(f"""
## 🔗 API Access
This Space provides both web interface and programmatic API access:
### **REST API Endpoint**
```
POST https://kylanoconnor-plonk-geolocation.hf.space/api/predict
```
### **Python Example**
```python
import requests
# For API access
response = requests.post(
"https://kylanoconnor-plonk-geolocation.hf.space/api/predict",
files={{"file": open("image.jpg", "rb")}}
)
result = response.json()
print(f"Location: {{result['data']['latitude']}}, {{result['data']['longitude']}}")
```
### **cURL Example**
```bash
curl -X POST \\
-F "data=@image.jpg" \\
"https://kylanoconnor-plonk-geolocation.hf.space/api/predict"
```
### **Gradio Client (Python)**
```python
from gradio_client import Client
client = Client("kylanoconnor/plonk-geolocation")
result = client.predict("path/to/image.jpg", api_name="/predict")
print(result)
```
### **Structured JSON (Gradio Client)**
```python
from gradio_client import Client
client = Client("kylanoconnor/plonk-geolocation")
result = client.predict("path/to/image.jpg", api_name="/predict_json")
print(result["predicted_location"])
```
### **JavaScript/Node.js**
```javascript
const formData = new FormData();
formData.append('data', imageFile);
const response = await fetch(
'https://kylanoconnor-plonk-geolocation.hf.space/api/predict',
{{
method: 'POST',
body: formData
}}
);
const result = await response.json();
console.log('Location:', result.data);
```
**Current Status:** {'🧪 Mock Mode - Returns realistic test coordinates' if MOCK_MODE else '🚀 Production Mode - Real PLONK predictions with 32 samples'}
**Response Format:**
- Latitude/Longitude coordinates
- Uncertainty estimation (±km radius)
- Number of samples used (32 for production)
- Prediction confidence metrics
**Rate Limits:** Standard Hugging Face Spaces limits apply
**CORS:** Enabled for web integration
""")

    with gr.Tab("ℹ️ About"):
        gr.Markdown(f"""
## About PLONK
PLONK is a generative approach to global visual geolocation that uses diffusion models to predict where images were taken.
**Paper:** [Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation](https://arxiv.org/abs/2412.06781)
**Authors:** Nicolas Dufour, David Picard, Vicky Kalogeiton, Loic Landrieu
**Original Code:** https://github.com/nicolas-dufour/plonk
### Current Deployment
- **Mode:** {'Mock Testing' if MOCK_MODE else 'Production'}
- **Model:** {'Simulated predictions for API testing' if MOCK_MODE else 'Real PLONK model inference'}
- **Response Format:** Structured JSON + formatted text
- **API:** Fully functional REST endpoints
### Production Deployment
This Space is running with the real PLONK model using:
- **Model:** nicolas-dufour/PLONK_YFCC
- **Dataset:** YFCC-100M
- **Inference:** CFG=2.0, 32 samples, 32 timesteps for high quality predictions
- **Uncertainty:** Statistical analysis across 32 predictions for reliability estimation
### Available Models
- `nicolas-dufour/PLONK_YFCC` - YFCC-100M dataset
- `nicolas-dufour/PLONK_iNaturalist` - iNaturalist dataset
- `nicolas-dufour/PLONK_OSV_5M` - OpenStreetView-5M dataset
""")

    # Event handlers
    predict_btn.click(
        fn=predict_location,
        inputs=[image_input],
        outputs=[output_text],
        api_name="predict"  # This enables API access at /api/predict
    )

    # JSON API endpoint. A gr.Interface constructed inside this Blocks context
    # is a separate app that is never rendered into `demo`, so its endpoint
    # was not exposed. Register the JSON handler on an invisible button
    # instead; the event still gets its own API name.
    json_output = gr.JSON(visible=False)
    json_api_btn = gr.Button(visible=False)
    json_api_btn.click(
        fn=predict_location_json,
        inputs=[image_input],
        outputs=[json_output],
        api_name="predict_json"  # Available at /api/predict_json
    )

    # Add examples only for bundled files that actually exist.
    example_files = [
        "demo/examples/condor.jpg",
        "demo/examples/Kilimanjaro.jpg",
        "demo/examples/pigeon.png",
    ]
    available_examples = [[path] for path in example_files if Path(path).is_file()]
    if available_examples:
        try:
            gr.Examples(
                examples=available_examples,
                inputs=image_input,
                outputs=output_text,
                fn=predict_location,
                cache_examples=True
            )
        except Exception as exc:
            # Examples are optional; log and continue without them.
            print(f"Skipping examples: {exc}")
if __name__ == "__main__":
    # Script entry point: serve on every interface at port 7860, with
    # Gradio's API information enabled via show_api.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": True,
    }
    demo.launch(**launch_options)