import os
import tempfile
from pathlib import Path

# Set memory optimization environment variables before the heavy imports
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'

import gradio as gr
import datetime
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.tri as tri
from anemoi.inference.runners.simple import SimpleRunner
from ecmwf.opendata import Client as OpendataClient
import earthkit.data as ekd
import earthkit.regrid as ekr
import matplotlib.animation as animation
from functools import lru_cache
import hashlib
import pickle
import json
from typing import List, Dict, Any, Optional
import logging
import xarray as xr
import pandas as pd

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Define parameters (matching notebook.py)
PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
PARAM_SOIL = ["vsw", "sot"]
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
SOIL_LEVELS = [1, 2]

DEFAULT_DATE = OpendataClient().latest()

# First organize variables into categories
VARIABLE_GROUPS = {
    "Surface Variables": {
        "10u": "10m U Wind Component",
        "10v": "10m V Wind Component",
        "2d": "2m Dewpoint Temperature",
        "2t": "2m Temperature",
        "msl": "Mean Sea Level Pressure",
        "skt": "Skin Temperature",
        "sp": "Surface Pressure",
        "tcw": "Total Column Water",
        "lsm": "Land-Sea Mask",
        "z": "Surface Geopotential",
        "slor": "Slope of Sub-gridscale Orography",
        "sdor": "Standard Deviation of Orography",
    },
    "Soil Variables": {
        "stl1": "Soil Temperature Level 1",
        "stl2": "Soil Temperature Level 2",
        "swvl1": "Soil Water Volume Level 1",
        "swvl2": "Soil Water Volume Level 2",
    },
    "Pressure Level Variables": {},  # Will fill this dynamically
}

# Add pressure level variables dynamically
for var in ["t", "u", "v", "w", "q", "z"]:
    var_name = {
        "t": "Temperature",
        "u": "U Wind Component",
        "v": "V Wind Component",
        "w": "Vertical Velocity",
        "q": "Specific Humidity",
        "z": "Geopotential",
    }[var]
    for level in LEVELS:
        var_id = f"{var}_{level}"
        VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"
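
# The loop above yields one entry per variable/level pair, e.g.
#   "t_850" -> "Temperature at 850hPa"
#   "q_100" -> "Specific Humidity at 100hPa"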

# Load the model once at startup
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda")  # Default to CUDA

# Create and set custom temp directory
TEMP_DIR = Path("./gradio_temp")
TEMP_DIR.mkdir(exist_ok=True)
os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)

# Cache-related helpers (kept after the MODEL initialization)
def get_cache_key(date: datetime.datetime, params: List[str], levellist: List[int]) -> str:
    """Create a unique cache key based on the request parameters."""
    key_parts = [
        date.isoformat(),
        ",".join(sorted(params)),
        ",".join(str(x) for x in sorted(levellist)) if levellist else "no_levels",
    ]
    key_string = "_".join(key_parts)
    cache_key = hashlib.md5(key_string.encode()).hexdigest()
    logger.info(f"Generated cache key: {cache_key} for {key_string}")
    return cache_key
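
# Illustrative example: get_cache_key(datetime.datetime(2025, 1, 1), ["2t", "msl"], [])
# hashes the string "2025-01-01T00:00:00_2t,msl_no_levels" into a 32-character
# MD5 hex digest, which becomes the cache filename below.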

def get_cache_path(cache_key: str) -> Path:
    """Get the path to the cache file."""
    return TEMP_DIR / "data_cache" / f"{cache_key}.pkl"

def save_to_cache(cache_key: str, data: Dict[str, Any]) -> None:
    """Save data to disk cache."""
    cache_file = get_cache_path(cache_key)
    try:
        with open(cache_file, 'wb') as f:
            pickle.dump(data, f)
        logger.info(f"Successfully saved data to cache: {cache_file}")
    except Exception as e:
        logger.error(f"Failed to save to cache: {e}")

def load_from_cache(cache_key: str) -> Optional[Dict[str, Any]]:
    """Load data from disk cache; return None on a miss or a corrupt file."""
    cache_file = get_cache_path(cache_key)
    if cache_file.exists():
        try:
            with open(cache_file, 'rb') as f:
                data = pickle.load(f)
            logger.info(f"Successfully loaded data from cache: {cache_file}")
            return data
        except Exception as e:
            logger.error(f"Failed to load from cache: {e}")
            cache_file.unlink(missing_ok=True)
            return None
    logger.info(f"No cache file found: {cache_file}")
    return None
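
# Together, save_to_cache and load_from_cache implement a pickle-per-key disk
# cache: a corrupt cache file is deleted on read, so the next request simply
# falls through to a fresh download.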

# In-memory caching layer on top of the download path
@lru_cache(maxsize=32)
def get_cached_data(date_str: str, param_tuple: tuple, levelist_tuple: tuple) -> Dict[str, Any]:
    """Memory-cache wrapper for get_open_data_impl.

    Takes a string and tuples (rather than a datetime and lists) because
    lru_cache requires hashable arguments.
    """
    return get_open_data_impl(
        datetime.datetime.fromisoformat(date_str),
        list(param_tuple),
        list(levelist_tuple) if levelist_tuple else [],
    )

def get_open_data(param: List[str], levelist: Optional[List[int]] = None) -> Dict[str, Any]:
    """Main function to get data, with disk caching."""
    if levelist is None:
        levelist = []
    # Try disk cache first (more persistent than the memory cache)
    cache_key = get_cache_key(DEFAULT_DATE, param, levelist)
    logger.info(f"Checking cache for key: {cache_key}")
    cached_data = load_from_cache(cache_key)
    if cached_data is not None:
        logger.info(f"Cache hit for {cache_key}")
        return cached_data
    # If not in cache, download and process the data
    logger.info(f"Cache miss for {cache_key}, downloading fresh data")
    fields = get_open_data_impl(DEFAULT_DATE, param, levelist)
    # Save to disk cache
    save_to_cache(cache_key, fields)
    return fields

def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
    """Download and process data for the given date and the step 6 h earlier."""
    fields = {}
    # The model needs two input times: the analysis and the one 6 hours before it
    dates = [date - datetime.timedelta(hours=6), date]
    logger.info(f"Downloading data for dates: {dates}")
    for current_date in dates:
        logger.info(f"Fetching data for {current_date}")
        data = ekd.from_source("ecmwf-open-data", date=current_date, param=param, levelist=levelist)
        for f in data:
            assert f.to_numpy().shape == (721, 1440)
            # Shift longitudes from [0, 360) to [-180, 180), then regrid from
            # the 0.25° lat/lon grid to the model's N320 grid
            values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            fields.setdefault(name, []).append(values)
    # Stack the two time steps into a single matrix per parameter
    for name, arrays in fields.items():
        fields[name] = np.stack(arrays)
    return fields
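
# Each returned entry therefore has shape (2, n_points): the two input times
# stacked along axis 0, with n_points the size of the N320 reduced Gaussian
# grid (542,080 points) produced by the regridding above.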

def plot_forecast(state, selected_variable):
    logger.info(f"Plotting forecast for {selected_variable} at time {state['date']}")

    # Setup the figure and axis
    fig = plt.figure(figsize=(15, 8))
    ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))

    # Get the coordinates
    latitudes, longitudes = state["latitudes"], state["longitudes"]
    fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
    triangulation = tri.Triangulation(fixed_lons, latitudes)

    # Get the values
    values = state["fields"][selected_variable]
    logger.info(f"Value range: min={np.min(values):.2f}, max={np.max(values):.2f}")

    # Set map features
    ax.set_global()
    ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
    ax.coastlines(resolution='50m')
    ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
    ax.gridlines(draw_labels=True)

    # Create contour plot
    contour = ax.tricontourf(triangulation, values,
                             levels=20, transform=ccrs.PlateCarree(),
                             cmap='RdBu_r')

    # Add colorbar
    plt.colorbar(contour, ax=ax, orientation='horizontal', pad=0.05)

    # Format the date string
    forecast_time = state["date"]
    if isinstance(forecast_time, str):
        forecast_time = datetime.datetime.fromisoformat(forecast_time)
    time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")

    # Get variable description
    var_desc = None
    for group in VARIABLE_GROUPS.values():
        if selected_variable in group:
            var_desc = group[selected_variable]
            break
    var_name = var_desc if var_desc else selected_variable
    ax.set_title(f"{var_name} - {time_str}")

    # Save as PNG
    temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.png")
    plt.savefig(temp_file, bbox_inches='tight', dpi=100)
    plt.close(fig)
    return temp_file
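
# Illustrative usage: given a state from run_forecast (defined below),
#   plot_forecast(state, "t_850")
# renders 850 hPa temperature and returns the path of the saved PNG.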

def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
    # Get all required fields
    fields = {}
    logger.info(f"Starting forecast for lead_time: {lead_time} hours")

    # Get surface fields
    logger.info("Getting surface fields...")
    fields.update(get_open_data(param=PARAM_SFC))

    # Get soil fields and rename them to the names the model expects
    logger.info("Getting soil fields...")
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
    mapping = {
        'sot_1': 'stl1', 'sot_2': 'stl2',
        'vsw_1': 'swvl1', 'vsw_2': 'swvl2',
    }
    for k, v in soil.items():
        fields[mapping[k]] = v

    # Get pressure level fields
    logger.info("Getting pressure level fields...")
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

    # Convert geopotential height (m) to geopotential (m² s⁻²): z = gh * g
    for level in LEVELS:
        gh = fields.pop(f"gh_{level}")
        fields[f"z_{level}"] = gh * 9.80665

    input_state = dict(date=date, fields=fields)

    # Use the global model instance; reload it only if the device changed
    global MODEL
    if device != MODEL.device:
        MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Run the model and keep the final state
    final_state = None
    for state in MODEL.run(input_state=input_state, lead_time=lead_time):
        logger.info(f"\n😀 date={state['date']} latitudes={state['latitudes'].shape} "
                    f"longitudes={state['longitudes'].shape} fields={len(state['fields'])}")
        # Log a few example variables to show we have all fields
        for var in ['2t', 'msl', 't_1000', 'z_850']:
            if var in state['fields']:
                values = state['fields'][var]
                logger.info(f"  {var:<6} shape={values.shape} "
                            f"min={np.min(values):.6f} "
                            f"max={np.max(values):.6f}")
        final_state = state

    logger.info(f"Final state contains {len(final_state['fields'])} variables")
    return final_state
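
# Illustrative usage: run_forecast(DEFAULT_DATE, 24, "cuda") returns the model
# state at the 24 h lead time; state["fields"]["2t"] then holds the forecast
# 2m temperature on the model grid.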

def get_available_variables(state):
    """Get available variables from the state and organize them into groups."""
    available_vars = set(state['fields'].keys())

    # Create dropdown choices only for available variables
    choices = []
    for group_name, variables in VARIABLE_GROUPS.items():
        group_vars = [(f"{desc} ({var_id})", var_id)
                      for var_id, desc in variables.items()
                      if var_id in available_vars]
        if group_vars:  # Only add group if it has available variables
            choices.append((f"── {group_name} ──", None))
            choices.extend(group_vars)
    return choices
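
# Unlike the static DROPDOWN_CHOICES built further below, this derives the
# choice list from the fields actually present in a forecast state, so it can
# be used to refresh the dropdown after a run.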

def save_forecast_data(state, format='json'):
    """Save forecast data in the specified format ('json', 'netcdf', or 'csv')."""
    if state is None:
        raise ValueError("No forecast data available. Please run a forecast first.")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    forecast_time = state['date'].strftime("%Y%m%d_%H") if isinstance(state['date'], datetime.datetime) else state['date']

    # Use the forecasts directory for all outputs
    output_dir = TEMP_DIR / "forecasts"

    if format == 'json':
        # Create a JSON-serializable dictionary
        data = {
            'metadata': {
                'forecast_date': forecast_time,
                'export_date': datetime.datetime.now().isoformat(),
                'total_points': len(state['latitudes']),
                'total_variables': len(state['fields']),
            },
            'coordinates': {
                'latitudes': state['latitudes'].tolist(),
                'longitudes': state['longitudes'].tolist(),
            },
            'fields': {
                var_name: {
                    'values': values.tolist(),
                    'statistics': {
                        'min': float(np.min(values)),
                        'max': float(np.max(values)),
                        'mean': float(np.mean(values)),
                        'std': float(np.std(values)),
                    },
                }
                for var_name, values in state['fields'].items()
            },
        }
        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.json"
        with open(output_file, 'w') as f:
            json.dump(data, f, indent=2)
        return str(output_file)

    elif format == 'netcdf':
        # Create an xarray Dataset indexed by grid point
        data_vars = {}
        coords = {
            'point': np.arange(len(state['latitudes'])),
            'latitude': ('point', state['latitudes']),
            'longitude': ('point', state['longitudes']),
        }
        # Add each field as a variable
        for var_name, values in state['fields'].items():
            data_vars[var_name] = (['point'], values)
        # Create the dataset
        ds = xr.Dataset(
            data_vars=data_vars,
            coords=coords,
            attrs={
                'forecast_date': forecast_time,
                'export_date': datetime.datetime.now().isoformat(),
                'description': 'AIFS Weather Forecast Data',
            },
        )
        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.nc"
        ds.to_netcdf(output_file)
        return str(output_file)

    elif format == 'csv':
        # Create a DataFrame with lat/lon and all variables as columns
        df = pd.DataFrame({
            'latitude': state['latitudes'],
            'longitude': state['longitudes'],
        })
        # Add each field as a column
        for var_name, values in state['fields'].items():
            df[var_name] = values
        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.csv"
        df.to_csv(output_file, index=False)
        return str(output_file)

    else:
        raise ValueError(f"Unsupported format: {format}")

# Create dropdown choices with groups
DROPDOWN_CHOICES = []
for group_name, variables in VARIABLE_GROUPS.items():
    # Add group separator
    DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
    # Add variables in this group
    for var_id, desc in sorted(variables.items()):
        DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))

def update_interface():
    with gr.Blocks(css="""
        .centered-header {
            text-align: center;
            margin-bottom: 20px;
        }
        .subtitle {
            font-size: 1.2em;
            line-height: 1.5;
            margin: 20px 0;
        }
        .footer {
            text-align: center;
            padding: 20px;
            margin-top: 20px;
            border-top: 1px solid #eee;
        }
    """) as demo:
        forecast_state = gr.State(None)

        # Header section
        gr.Markdown(f"""
        # AIFS Weather Forecast

        <div class="subtitle">
        Interactive visualization of ECMWF AIFS weather forecasts.<br>
        Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
        select how many hours ahead you want to forecast and which meteorological variable to visualize.
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                lead_time = gr.Slider(
                    minimum=6,
                    maximum=48,
                    step=6,
                    value=12,
                    label="Forecast Hours Ahead",
                )
                # Start with the original DROPDOWN_CHOICES at startup
                variable = gr.Dropdown(
                    choices=DROPDOWN_CHOICES,
                    value="2t",
                    label="Select Variable to Plot",
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear")
                    run_btn = gr.Button("Run Forecast", variant="primary")
                download_nc = gr.Button("Download Forecast (NetCDF)")
                download_output = gr.File(label="Download Output")
            with gr.Column(scale=2):
                forecast_output = gr.Image()

        def run_and_store(lead_time):
            """Run forecast and store the resulting state."""
            state = run_forecast(DEFAULT_DATE, lead_time, "cuda")
            plot = plot_forecast(state, "2t")  # Default to 2t
            return state, plot

        def update_plot_from_state(forecast_state, variable):
            """Update plot using stored state."""
            if forecast_state is None or variable is None:
                return None
            try:
                return plot_forecast(forecast_state, variable)
            except KeyError as e:
                logger.error(f"Variable {variable} not found in state: {e}")
                return None

        def clear():
            """Reset state, plot, slider, and variable selection."""
            return [None, None, 12, "2t"]

        def save_netcdf(forecast_state):
            """Save forecast data as NetCDF (delegates to save_forecast_data,
            whose 'netcdf' branch this previously duplicated verbatim)."""
            return save_forecast_data(forecast_state, format='netcdf')

        # Connect the components
        run_btn.click(
            fn=run_and_store,
            inputs=[lead_time],
            outputs=[forecast_state, forecast_output],
        )
        variable.change(
            fn=update_plot_from_state,
            inputs=[forecast_state, variable],
            outputs=forecast_output,
        )
        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[forecast_state, forecast_output, lead_time, variable],
        )
        download_nc.click(
            fn=save_netcdf,
            inputs=[forecast_state],
            outputs=[download_output],
        )

    return demo

def setup_directories():
    """Create necessary directories with .keep files."""
    # Define all required directories
    directories = {
        TEMP_DIR / "data_cache": "Cache directory for downloaded weather data",
        TEMP_DIR / "forecasts": "Directory for forecast outputs (plots and data files)",
    }
    # Create directories and .keep files
    for directory, description in directories.items():
        directory.mkdir(parents=True, exist_ok=True)
        keep_file = directory / ".keep"
        if not keep_file.exists():
            keep_file.write_text(f"# {description}\n# This file ensures the directory is tracked in git\n")
            logger.info(f"Created directory and .keep file: {directory}")

# Call it during initialization
setup_directories()
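
# Resulting layout under ./gradio_temp: data_cache/ holds pickled downloads
# keyed by MD5 hash, forecasts/ holds exported JSON/NetCDF/CSV files, and the
# plot PNGs are written directly to the gradio_temp root.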

def cleanup_old_files():
    """Remove old temporary and cache files."""
    current_time = datetime.datetime.now().timestamp()

    # Clean up forecast files older than 1 hour
    forecast_dir = TEMP_DIR / "forecasts"
    for file in forecast_dir.glob("*.*"):
        if file.name == ".keep":
            continue
        if current_time - file.stat().st_mtime > 3600:
            logger.info(f"Removing old forecast file: {file}")
            file.unlink(missing_ok=True)

    # Clean up cache files older than 24 hours
    cache_dir = TEMP_DIR / "data_cache"
    for file in cache_dir.glob("*.pkl"):
        if file.name == ".keep":
            continue
        if current_time - file.stat().st_mtime > 86400:
            logger.info(f"Removing old cache file: {file}")
            file.unlink(missing_ok=True)

cleanup_old_files()

# Create and launch the interface last: launch() blocks, so the directory
# setup and cleanup above must run before it
demo = update_interface()
demo.launch()