Spaces:
add variables
Browse files

- .gitignore +29 -0
- app.py +244 -80
- gradio_temp/.keep +0 -0
.gitignore
CHANGED
@@ -1,3 +1,32 @@
 aifs-single-mse-1.0.ckpt
 flagged/
 gradio_temp/*
+
+# Ignore all files in temp directories except .keep
+gradio_temp/data_cache/*
+!gradio_temp/data_cache/.keep
+
+gradio_temp/forecasts/*
+!gradio_temp/forecasts/.keep
+
+# Python cache files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Environment directories
+.env
+.venv
+env/
+venv/
+ENV/
+
+# IDE directories
+.idea/
+.vscode/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# Logs
+*.log
app.py
CHANGED
@@ -23,6 +23,8 @@ import pickle
 import json
 from typing import List, Dict, Any
 import logging
+import xarray as xr
+import pandas as pd
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -71,7 +73,7 @@ for var in ["t", "u", "v", "w", "q", "z"]:
         "q": "Specific Humidity",
         "z": "Geopotential"
     }[var]
-    
+
     for level in LEVELS:
         var_id = f"{var}_{level}"
         VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa"
@@ -99,9 +101,7 @@ def get_cache_key(date: datetime.datetime, params: List[str], levellist: List[int]) -> str:
 
 def get_cache_path(cache_key: str) -> Path:
     """Get the path to the cache file"""
-    
-    cache_dir.mkdir(exist_ok=True)
-    return cache_dir / f"{cache_key}.pkl"
+    return TEMP_DIR / "data_cache" / f"{cache_key}.pkl"
 
 def save_to_cache(cache_key: str, data: Dict[str, Any]) -> None:
     """Save data to disk cache"""
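The cache helpers that `get_cache_path` feeds are referenced throughout this diff but their bodies are not shown. A minimal sketch of what `load_from_cache` / `save_to_cache` might look like, assuming the pickle-based layout implied above (`pickle` is already imported in app.py; these bodies are guesses, not the commit's code):

```python
import pickle
from typing import Any, Dict, Optional

def load_from_cache(cache_key: str) -> Optional[Dict[str, Any]]:
    """Return the cached fields dict for cache_key, or None on a miss."""
    path = get_cache_path(cache_key)  # gradio_temp/data_cache/<key>.pkl
    if not path.exists():
        return None
    with open(path, "rb") as f:
        return pickle.load(f)

def save_to_cache(cache_key: str, data: Dict[str, Any]) -> None:
    """Pickle the fields dict alongside the other cached downloads."""
    path = get_cache_path(cache_key)
    path.parent.mkdir(parents=True, exist_ok=True)  # the new get_cache_path no longer creates it
    with open(path, "wb") as f:
        pickle.dump(data, f)
```

Note the behavioural change in the hunk above: the old `get_cache_path` created the cache directory itself (`cache_dir.mkdir`), while the new one only builds the path and leaves directory creation to the `setup_directories()` call added at the bottom of the file.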
@@ -142,23 +142,23 @@ def get_open_data(param: List[str], levelist: List[int] = None) -> Dict[str, Any]:
     """Main function to get data with caching"""
     if levelist is None:
         levelist = []
-    
+
     # Try disk cache first (more persistent than memory cache)
     cache_key = get_cache_key(DEFAULT_DATE, param, levelist)
     logger.info(f"Checking cache for key: {cache_key}")
-    
+
     cached_data = load_from_cache(cache_key)
     if cached_data is not None:
         logger.info(f"Cache hit for {cache_key}")
         return cached_data
-    
+
     # If not in cache, download and process the data
     logger.info(f"Cache miss for {cache_key}, downloading fresh data")
     fields = get_open_data_impl(DEFAULT_DATE, param, levelist)
-    
+
     # Save to disk cache
     save_to_cache(cache_key, fields)
-    
+
     return fields
 
 def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
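`get_cache_key` itself is unchanged by this commit, so only its signature (visible in the hunk header above) appears here. A plausible, entirely hypothetical body consistent with that signature, for readers following the caching flow:

```python
import datetime
import hashlib
from typing import List

def get_cache_key(date: datetime.datetime, params: List[str], levellist: List[int]) -> str:
    """Derive a stable, filename-safe key from the request parameters."""
    raw = f"{date.isoformat()}_{sorted(params)}_{sorted(levellist)}"
    return hashlib.md5(raw.encode()).hexdigest()
```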
@@ -166,7 +166,7 @@ def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
     fields = {}
     myiterable = [date - datetime.timedelta(hours=6), date]
     logger.info(f"Downloading data for dates: {myiterable}")
-    
+
     for current_date in myiterable:
         logger.info(f"Fetching data for {current_date}")
         data = ekd.from_source("ecmwf-open-data", date=current_date, param=param, levelist=levelist)
@@ -178,50 +178,50 @@ def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]:
         if name not in fields:
             fields[name] = []
         fields[name].append(values)
-    
+
     # Create a single matrix for each parameter
     for param, values in fields.items():
         fields[param] = np.stack(values)
-    
+
     return fields
 
 def plot_forecast(state, selected_variable):
     logger.info(f"Plotting forecast for {selected_variable} at time {state['date']}")
-    
+
     # Setup the figure and axis
     fig = plt.figure(figsize=(15, 8))
     ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))
-    
+
     # Get the coordinates
     latitudes, longitudes = state["latitudes"], state["longitudes"]
     fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
     triangulation = tri.Triangulation(fixed_lons, latitudes)
-    
+
     # Get the values
     values = state["fields"][selected_variable]
     logger.info(f"Value range: min={np.min(values):.2f}, max={np.max(values):.2f}")
-    
+
     # Set map features
     ax.set_global()
     ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
     ax.coastlines(resolution='50m')
     ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
     ax.gridlines(draw_labels=True)
-    
+
     # Create contour plot
     contour = ax.tricontourf(triangulation, values,
                              levels=20, transform=ccrs.PlateCarree(),
                              cmap='RdBu_r')
-    
+
     # Add colorbar
     plt.colorbar(contour, ax=ax, orientation='horizontal', pad=0.05)
-    
+
     # Format the date string
     forecast_time = state["date"]
     if isinstance(forecast_time, str):
         forecast_time = datetime.datetime.fromisoformat(forecast_time)
     time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")
-    
+
     # Get variable description
     var_desc = None
     for group in VARIABLE_GROUPS.values():
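One detail in `plot_forecast` worth calling out: the grid arrives with 0–360° longitudes (implied by the remap above), and the `np.where` line folds them into −180–180° before triangulating, so `tricontourf` does not draw triangles stitched across the dateline. A quick check of the remapping:

```python
import numpy as np

lons = np.array([0.0, 90.0, 180.0, 270.0, 359.5])
fixed = np.where(lons > 180, lons - 360, lons)
print(fixed)  # [  0.   90.  180.  -90.   -0.5]
```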
@@ -229,25 +229,25 @@ def plot_forecast(state, selected_variable):
             var_desc = group[selected_variable]
             break
     var_name = var_desc if var_desc else selected_variable
-    
+
     ax.set_title(f"{var_name} - {time_str}")
-    
+
     # Save as PNG
     temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.png")
     plt.savefig(temp_file, bbox_inches='tight', dpi=100)
     plt.close()
-    
+
     return temp_file
 
 def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
     # Get all required fields
     fields = {}
     logger.info(f"Starting forecast for lead_time: {lead_time} hours")
-    
+
     # Get surface fields
     logger.info("Getting surface fields...")
     fields.update(get_open_data(param=PARAM_SFC))
-    
+
     # Get soil fields and rename them
     logger.info("Getting soil fields...")
     soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
@@ -257,29 +257,29 @@ def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
     }
     for k, v in soil.items():
         fields[mapping[k]] = v
-    
+
     # Get pressure level fields
     logger.info("Getting pressure level fields...")
     fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))
-    
+
     # Convert geopotential height to geopotential
     for level in LEVELS:
         gh = fields.pop(f"gh_{level}")
         fields[f"z_{level}"] = gh * 9.80665
-    
+
     input_state = dict(date=date, fields=fields)
-    
+
     # Use the global model instance
     global MODEL
     if device != MODEL.device:
         MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)
-    
+
     # Run the model and get the final state
     final_state = None
     for state in MODEL.run(input_state=input_state, lead_time=lead_time):
         logger.info(f"\n😀 date={state['date']} latitudes={state['latitudes'].shape} "
                     f"longitudes={state['longitudes'].shape} fields={len(state['fields'])}")
-        
+
         # Log a few example variables to show we have all fields
         for var in ['2t', 'msl', 't_1000', 'z_850']:
             if var in state['fields']:
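The `gh` → `z` conversion in this hunk is a unit change, not a model step: geopotential height in metres times standard gravity g0 = 9.80665 m s⁻² gives geopotential in m² s⁻², which is the `z_<level>` input the model is fed here. For example:

```python
g0 = 9.80665      # standard gravity, m s^-2
gh_500 = 5570.0   # a typical 500 hPa geopotential height, metres
z_500 = gh_500 * g0
print(round(z_500))  # 54623 (m^2 s^-2)
```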
@@ -287,29 +287,130 @@ def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]:
                 logger.info(f"  {var:<6} shape={values.shape} "
                             f"min={np.min(values):.6f} "
                             f"max={np.max(values):.6f}")
-        
+
         final_state = state
-    
+
     logger.info(f"Final state contains {len(final_state['fields'])} variables")
     return final_state
 
 def get_available_variables(state):
     """Get available variables from the state and organize them into groups"""
     available_vars = set(state['fields'].keys())
-    
+
     # Create dropdown choices only for available variables
     choices = []
     for group_name, variables in VARIABLE_GROUPS.items():
-        group_vars = [(f"{desc} ({var_id})", var_id)
-                      for var_id, desc in variables.items()
+        group_vars = [(f"{desc} ({var_id})", var_id)
+                      for var_id, desc in variables.items()
                       if var_id in available_vars]
-        
+
         if group_vars:  # Only add group if it has available variables
             choices.append((f"── {group_name} ──", None))
             choices.extend(group_vars)
-    
+
     return choices
 
+def save_forecast_data(state, format='json'):
+    """Save forecast data in specified format"""
+    if state is None:
+        raise ValueError("No forecast data available. Please run a forecast first.")
+
+    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    forecast_time = state['date'].strftime("%Y%m%d_%H") if isinstance(state['date'], datetime.datetime) else state['date']
+
+    # Use forecasts directory for all outputs
+    output_dir = TEMP_DIR / "forecasts"
+
+    if format == 'json':
+        # Create a JSON-serializable dictionary
+        data = {
+            'metadata': {
+                'forecast_date': forecast_time,
+                'export_date': datetime.datetime.now().isoformat(),
+                'total_points': len(state['latitudes']),
+                'total_variables': len(state['fields'])
+            },
+            'coordinates': {
+                'latitudes': state['latitudes'].tolist(),
+                'longitudes': state['longitudes'].tolist()
+            },
+            'fields': {
+                var_name: {
+                    'values': values.tolist(),
+                    'statistics': {
+                        'min': float(np.min(values)),
+                        'max': float(np.max(values)),
+                        'mean': float(np.mean(values)),
+                        'std': float(np.std(values))
+                    }
+                }
+                for var_name, values in state['fields'].items()
+            }
+        }
+
+        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.json"
+        with open(output_file, 'w') as f:
+            json.dump(data, f, indent=2)
+
+        return str(output_file)
+
+    elif format == 'netcdf':
+        # Create an xarray Dataset
+        data_vars = {}
+        coords = {
+            'point': np.arange(len(state['latitudes'])),
+            'latitude': ('point', state['latitudes']),
+            'longitude': ('point', state['longitudes']),
+        }
+
+        # Add each field as a variable
+        for var_name, values in state['fields'].items():
+            data_vars[var_name] = (['point'], values)
+
+        # Create the dataset
+        ds = xr.Dataset(
+            data_vars=data_vars,
+            coords=coords,
+            attrs={
+                'forecast_date': forecast_time,
+                'export_date': datetime.datetime.now().isoformat(),
+                'description': 'AIFS Weather Forecast Data'
+            }
+        )
+
+        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.nc"
+        ds.to_netcdf(output_file)
+
+        return str(output_file)
+
+    elif format == 'csv':
+        # Create a DataFrame with lat/lon and all variables
+        df = pd.DataFrame({
+            'latitude': state['latitudes'],
+            'longitude': state['longitudes']
+        })
+
+        # Add each field as a column
+        for var_name, values in state['fields'].items():
+            df[var_name] = values
+
+        output_file = output_dir / f"forecast_{forecast_time}_{timestamp}.csv"
+        df.to_csv(output_file, index=False)
+
+        return str(output_file)
+
+    else:
+        raise ValueError(f"Unsupported format: {format}")
+
+# Create dropdown choices with groups
+DROPDOWN_CHOICES = []
+for group_name, variables in VARIABLE_GROUPS.items():
+    # Add group separator
+    DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
+    # Add variables in this group
+    for var_id, desc in sorted(variables.items()):
+        DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id))
+
 def update_interface():
     with gr.Blocks(css="""
         .centered-header {
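Taken together with `run_forecast`, the new `save_forecast_data` also gives a small offline workflow outside the UI. A usage sketch (hypothetical call site; the `"cuda"` device and 12 h lead time mirror the defaults used elsewhere in this commit):

```python
state = run_forecast(DEFAULT_DATE, lead_time=12, device="cuda")

# Any of the three export branches can be exercised directly:
json_path = save_forecast_data(state, format='json')
nc_path = save_forecast_data(state, format='netcdf')
csv_path = save_forecast_data(state, format='csv')
print(nc_path)  # gradio_temp/forecasts/forecast_<date>_<timestamp>.nc
```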
@@ -328,7 +429,18 @@ def update_interface():
         border-top: 1px solid #eee;
     }
     """) as demo:
-        
+        forecast_state = gr.State(None)
+
+        # Header section
+        gr.Markdown(f"""
+        # AIFS Weather Forecast
+
+        <div class="subtitle">
+        Interactive visualization of ECMWF AIFS weather forecasts.<br>
+        Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
+        select how many hours ahead you want to forecast and which meteorological variable to visualize.
+        </div>
+        """)
 
         with gr.Row():
             with gr.Column(scale=1):
@@ -339,90 +451,101 @@ def update_interface():
                     value=12,
                     label="Forecast Hours Ahead"
                 )
+                # Start with the original DROPDOWN_CHOICES
                 variable = gr.Dropdown(
-                    choices=
-                    value=
+                    choices=DROPDOWN_CHOICES,  # Use original choices at startup
+                    value="2t",
                     label="Select Variable to Plot"
                 )
                 with gr.Row():
                     clear_btn = gr.Button("Clear")
                     run_btn = gr.Button("Run Forecast", variant="primary")
 
-
-
-                download_nc = gr.Button("Download NetCDF")
+                download_nc = gr.Button("Download Forecast (NetCDF)")
+                download_output = gr.File(label="Download Output")
 
             with gr.Column(scale=2):
                 forecast_output = gr.Image()
 
         def run_and_store(lead_time):
             """Run forecast and store state"""
-
-
-
-            choices = get_available_variables(state)
-
-            # Select first real variable as default
-            default_var = next((var_id for _, var_id in choices if var_id is not None), None)
-
-            # Generate initial plot
-            plot = plot_forecast(state, default_var) if default_var else None
-
-            return [state, gr.Dropdown(choices=choices), default_var, plot]
+            forecast_state = run_forecast(DEFAULT_DATE, lead_time, "cuda")
+            plot = plot_forecast(forecast_state, "2t")  # Default to 2t
+            return forecast_state, plot
 
-        def update_plot_from_state(
+        def update_plot_from_state(forecast_state, variable):
             """Update plot using stored state"""
-            if 
+            if forecast_state is None or variable is None:
                 return None
             try:
-                return plot_forecast(
+                return plot_forecast(forecast_state, variable)
             except KeyError as e:
                 logger.error(f"Variable {variable} not found in state: {e}")
                 return None
 
         def clear():
             """Clear everything"""
-            return [None, None, 
+            return [None, None, 12, "2t"]
 
-        def save_json(state):
-            if state is None:
-                return None
-            return save_forecast_data(state, 'json')
-
-        def save_netcdf(
-
-
-
+        def save_netcdf(forecast_state):
+            """Save forecast data as NetCDF"""
+            if forecast_state is None:
+                raise ValueError("No forecast data available. Please run a forecast first.")
+
+            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+            forecast_time = forecast_state['date'].strftime("%Y%m%d_%H") if isinstance(forecast_state['date'], datetime.datetime) else forecast_state['date']
+
+            # Create an xarray Dataset
+            data_vars = {}
+            coords = {
+                'point': np.arange(len(forecast_state['latitudes'])),
+                'latitude': ('point', forecast_state['latitudes']),
+                'longitude': ('point', forecast_state['longitudes']),
+            }
+
+            # Add each field as a variable
+            for var_name, values in forecast_state['fields'].items():
+                data_vars[var_name] = (['point'], values)
+
+            # Create the dataset
+            ds = xr.Dataset(
+                data_vars=data_vars,
+                coords=coords,
+                attrs={
+                    'forecast_date': forecast_time,
+                    'export_date': datetime.datetime.now().isoformat(),
+                    'description': 'AIFS Weather Forecast Data'
+                }
+            )
+
+            output_file = TEMP_DIR / "forecasts" / f"forecast_{forecast_time}_{timestamp}.nc"
+            ds.to_netcdf(output_file)
+
+            return str(output_file)
 
         # Connect the components
         run_btn.click(
             fn=run_and_store,
             inputs=[lead_time],
-            outputs=[
+            outputs=[forecast_state, forecast_output]
         )
 
         variable.change(
             fn=update_plot_from_state,
-            inputs=[
+            inputs=[forecast_state, variable],
             outputs=forecast_output
         )
 
         clear_btn.click(
             fn=clear,
             inputs=[],
-            outputs=[
-        )
-
-        download_json.click(
-            fn=save_json,
-            inputs=[state],
-            outputs=gr.File()
+            outputs=[forecast_state, forecast_output, lead_time, variable]
         )
 
         download_nc.click(
             fn=save_netcdf,
-            inputs=[
-            outputs=
+            inputs=[forecast_state],
+            outputs=[download_output]
        )
 
         return demo
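The wiring above leans on `gr.State`: whatever `run_and_store` returns into `forecast_state` is held per session and handed back to `update_plot_from_state` and `save_netcdf` as a plain Python object. The same pattern in miniature (a self-contained sketch, not part of the commit):

```python
import gradio as gr

with gr.Blocks() as mini:
    state = gr.State(None)          # survives between callbacks, per session
    store_btn = gr.Button("Store")
    read_btn = gr.Button("Read")
    out = gr.Textbox()

    def store():
        return {"answer": 42}       # a value returned to a State output is kept

    def read(s):
        return "empty" if s is None else str(s["answer"])

    store_btn.click(fn=store, inputs=[], outputs=[state])
    read_btn.click(fn=read, inputs=[state], outputs=[out])

mini.launch()
```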
@@ -430,3 +553,44 @@ def update_interface():
 # Create and launch the interface
 demo = update_interface()
 demo.launch()
+
+def setup_directories():
+    """Create necessary directories with .keep files"""
+    # Define all required directories
+    directories = {
+        TEMP_DIR / "data_cache": "Cache directory for downloaded weather data",
+        TEMP_DIR / "forecasts": "Directory for forecast outputs (plots and data files)",
+    }
+
+    # Create directories and .keep files
+    for directory, description in directories.items():
+        directory.mkdir(parents=True, exist_ok=True)
+        keep_file = directory / ".keep"
+        if not keep_file.exists():
+            keep_file.write_text(f"# {description}\n# This file ensures the directory is tracked in git\n")
+            logger.info(f"Created directory and .keep file: {directory}")
+
+# Call it during initialization
+setup_directories()
+
+def cleanup_old_files():
+    """Remove old temporary and cache files"""
+    current_time = datetime.datetime.now().timestamp()
+
+    # Clean up forecast files (1 hour old)
+    forecast_dir = TEMP_DIR / "forecasts"
+    for file in forecast_dir.glob("*.*"):
+        if file.name == ".keep":
+            continue
+        if current_time - file.stat().st_mtime > 3600:
+            logger.info(f"Removing old forecast file: {file}")
+            file.unlink(missing_ok=True)
+
+    # Clean up cache files (24 hours old)
+    cache_dir = TEMP_DIR / "data_cache"
+    for file in cache_dir.glob("*.pkl"):
+        if file.name == ".keep":
+            continue
+        if current_time - file.stat().st_mtime > 86400:
+            logger.info(f"Removing old cache file: {file}")
+            file.unlink(missing_ok=True)
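One loose end visible in this hunk: `setup_directories()` is called, but `cleanup_old_files()` is only defined, never invoked, and both sit after the blocking `demo.launch()` call. A sketch of one way the cleanup could be run periodically (hypothetical; not in the commit):

```python
import threading

def schedule_cleanup(interval_seconds: int = 3600) -> None:
    """Run cleanup_old_files() now, then again every interval_seconds."""
    cleanup_old_files()
    timer = threading.Timer(interval_seconds, schedule_cleanup, args=(interval_seconds,))
    timer.daemon = True   # don't keep the process alive just for cleanup
    timer.start()

# e.g. call schedule_cleanup() before demo.launch()
```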
gradio_temp/.keep
CHANGED

File without changes