Spaces:
Build error
Build error
| import os | |
# Memory-related settings. They are placed before the gradio/anemoi imports
# that follow, in case those libraries read the variables early.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "ANEMOI_INFERENCE_NUM_CHUNKS": "16",
    }
)
| import gradio as gr | |
| import datetime | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import cartopy.crs as ccrs | |
| import cartopy.feature as cfeature | |
| import matplotlib.tri as tri | |
| from anemoi.inference.runners.simple import SimpleRunner | |
| from ecmwf.opendata import Client as OpendataClient | |
| import earthkit.data as ekd | |
| import earthkit.regrid as ekr | |
# Variables requested from ECMWF open data (kept in sync with the reference
# notebook this app was derived from).

# Single-level (surface) parameters.
PARAM_SFC = [
    "10u", "10v", "2d", "2t", "msl", "skt",
    "sp", "tcw", "lsm", "z", "slor", "sdor",
]

# Soil parameters and the soil levels they are requested on.
PARAM_SOIL = ["vsw", "sot"]
SOIL_LEVELS = [1, 2]

# Pressure-level parameters and the pressure levels they are requested on.
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [
    1000, 925, 850, 700, 600, 500, 400,
    300, 250, 200, 150, 100, 50,
]

# Result of the open-data client's latest() query, evaluated once at import
# time (presumably the most recent published run).
DEFAULT_DATE = OpendataClient().latest()
def get_open_data(param, levelist=None, date=None):
    """Download ECMWF open-data fields and regrid them to the N320 grid.

    Fields are fetched for two times, 6 hours before ``date`` and ``date``
    itself, and stacked in that order.

    Args:
        param: Parameter short names to request (e.g. ``PARAM_SFC``).
        levelist: Levels to request. ``None`` or empty requests single-level
            fields, keyed by parameter name alone; otherwise fields are keyed
            as ``"<param>_<level>"``.
        date: Initial date. Defaults to ``DEFAULT_DATE``.

    Returns:
        dict mapping field name to an array that stacks the two times along
        axis 0.

    Raises:
        ValueError: if a downloaded field does not have shape (721, 1440).
    """
    # Fresh list per call instead of a shared mutable default argument.
    levelist = [] if levelist is None else levelist
    date = DEFAULT_DATE if date is None else date

    fields = {}
    for when in (date - datetime.timedelta(hours=6), date):
        data = ekd.from_source("ecmwf-open-data", date=when, param=param, levelist=levelist)
        for f in data:
            raw = f.to_numpy()
            # Explicit check rather than `assert`, which is stripped under -O.
            if raw.shape != (721, 1440):
                raise ValueError(
                    f"Expected a (721, 1440) field for {f.metadata('param')}, got {raw.shape}"
                )
            # Roll the columns by half the grid width before regridding, as the
            # reference notebook does. NOTE(review): presumably moves the
            # longitude origin to match the regular 0.25 grid ekr expects.
            values = np.roll(raw, -raw.shape[1] // 2, axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            if levelist:
                name = f"{f.metadata('param')}_{f.metadata('levelist')}"
            else:
                name = f.metadata("param")
            fields.setdefault(name, []).append(values)

    # One array per field, with the two times stacked along the first axis.
    return {name: np.stack(values) for name, values in fields.items()}


# Standard gravity, used to convert geopotential height ("gh") to
# geopotential ("z").
_STANDARD_GRAVITY = 9.80665

# Open-data soil field names -> names stored in the model input state.
_SOIL_MAPPING = {
    "sot_1": "stl1", "sot_2": "stl2",
    "vsw_1": "swvl1", "vsw_2": "swvl2",
}


def _collect_input_fields(date):
    """Fetch, rename and convert every open-data field for the input state."""
    fields = {}
    fields.update(get_open_data(param=PARAM_SFC, date=date))

    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS, date=date)
    for name, values in soil.items():
        fields[_SOIL_MAPPING[name]] = values

    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS, date=date))
    for level in LEVELS:
        gh = fields.pop(f"gh_{level}")
        fields[f"z_{level}"] = gh * _STANDARD_GRAVITY
    return fields


def run_forecast(date, lead_time, device):
    """Run the AIFS model from open-data initial conditions.

    Args:
        date: Initial datetime. It is stored in the input state AND used to
            select which open-data times are downloaded (``date`` and 6 h
            earlier), so the state date and its data agree.
        lead_time: Forecast length, passed unchanged to ``runner.run``.
        device: Compute device name passed to ``SimpleRunner``.

    Returns:
        The last state yielded by ``runner.run``.

    Raises:
        RuntimeError: if the runner yields no state at all.
    """
    input_state = dict(date=date, fields=_collect_input_fields(date))
    runner = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Keep only the most recent state; every state holds a full set of global
    # fields, so accumulating all of them wastes memory.
    last_state = None
    for state in runner.run(input_state=input_state, lead_time=lead_time):
        last_state = state
    if last_state is None:
        raise RuntimeError(f"Runner produced no forecast state for lead_time={lead_time!r}")
    return last_state
def plot_forecast(state):
    """Plot the 100 m u-wind component of a forecast state on a world map.

    Args:
        state: Forecast state as returned by ``run_forecast``; reads
            ``"latitudes"``, ``"longitudes"``, ``"date"`` and
            ``state["fields"]["100u"]``.

    Returns:
        The matplotlib Figure. It is closed in pyplot before being returned,
        so repeated calls do not accumulate open figures; it can still be
        saved or rendered.
    """
    latitudes = state["latitudes"]
    lons = np.asarray(state["longitudes"])
    # NOTE(review): the model grid's longitudes may be in [0, 360) (the
    # upstream AIFS notebook applies this same shift). Map them to [-180, 180)
    # so every point lies inside the PlateCarree extent; values already in
    # that range are unchanged.
    longitudes = np.where(lons > 180, lons - 360, lons)
    values = state["fields"]["100u"]

    fig, ax = plt.subplots(figsize=(11, 6), subplot_kw={"projection": ccrs.PlateCarree()})
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=":")
    triangulation = tri.Triangulation(longitudes, latitudes)
    contour = ax.tricontourf(triangulation, values, levels=20, transform=ccrs.PlateCarree(), cmap="RdBu")
    # Figure-scoped calls instead of pyplot's implicit "current" figure/axes.
    ax.set_title(f"100m winds at {state['date']}")
    fig.colorbar(contour, ax=ax)
    # Release pyplot's reference; otherwise each request leaks a figure.
    plt.close(fig)
    return fig
def gradio_interface(date_str, lead_time, device):
    """Gradio callback: parse the inputs, run the forecast, and plot it.

    Args:
        date_str: Initial date as ``YYYY-MM-DD`` text; surrounding whitespace
            is ignored.
        lead_time: Forecast length in hours from the slider.
        device: Compute device name forwarded to ``run_forecast``.

    Returns:
        The figure produced by ``plot_forecast``.

    Raises:
        gr.Error: if ``date_str`` is not a valid ``YYYY-MM-DD`` date.
    """
    try:
        date = datetime.datetime.strptime((date_str or "").strip(), "%Y-%m-%d")
    except ValueError:
        raise gr.Error("Please enter a valid date in YYYY-MM-DD format") from None
    # Slider values may arrive as floats (e.g. 12.0); forward a whole number
    # of hours to the runner.
    state = run_forecast(date, int(lead_time), device)
    return plot_forecast(state)
# Gradio UI: initial date, lead time (hours) and compute device in; map out.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(value=DEFAULT_DATE.strftime("%Y-%m-%d"), label="Forecast Date (YYYY-MM-DD)"),
        gr.Slider(minimum=6, maximum=48, step=6, value=12, label="Lead Time (Hours)"),
        gr.Radio(choices=["cuda", "cpu"], value="cuda", label="Compute Device"),
    ],
    outputs=gr.Plot(),
    title="AIFS Weather Forecast",
    description="Run ECMWF AIFS forecasts based on selected parameters.",
)

# Launch only when run as a script, so importing the module (e.g. by tooling
# or gradio's reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()