import os, time, json, random

import requests
import gradio as gr
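
# Runtime configuration (set in the Space settings):
#   PROVIDER: "hf_model" (Inference API), "local" (in-process GPU), or "stub" (offline)
#   MODEL_ID: Hugging Face model repo to query
#   HF_TOKEN: optional token for the Inference API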
PROVIDER = os.getenv("PROVIDER", "hf_model").strip()
MODEL_ID = os.getenv("MODEL_ID", "MBZUAI-IFM/K2-Think-SFT").strip()
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()

def _get(url, params=None, headers=None, timeout=12, retries=2, backoff=1.6):
    # GET with retries, exponential backoff and a little jitter.
    for i in range(retries + 1):
        try:
            r = requests.get(url, params=params, headers=headers, timeout=timeout)
            r.raise_for_status()
            return r
        except Exception:
            if i == retries:
                raise
            time.sleep((backoff ** i) + random.random() * 0.25)

def geocode_city(city: str):
    # Nominatim asks for an identifying User-Agent header.
    r = _get("https://nominatim.openstreetmap.org/search",
             params={"q": city, "format": "json", "limit": 1},
             headers={"User-Agent": "climamind-space"})
    j = r.json()
    if not j:
        raise RuntimeError("City not found")
    return {"lat": float(j[0]["lat"]), "lon": float(j[0]["lon"]), "name": j[0]["display_name"]}

def fetch_open_meteo(lat, lon):
    r = _get("https://api.open-meteo.com/v1/forecast", params={
        "latitude": lat, "longitude": lon,
        "current": "temperature_2m,relative_humidity_2m,wind_speed_10m,precipitation,uv_index",
        "hourly": "temperature_2m,relative_humidity_2m,wind_speed_10m,precipitation_probability,uv_index",
        "timezone": "auto"
    })
    return r.json()

def fetch_openaq_pm25(lat, lon):
    # Nearest PM2.5 reading within 10 km; returns None if nothing is found
    # (the "measurements" list shape is assumed here; adjust if the API payload differs).
    r = _get("https://api.openaq.org/v3/latest",
             params={"coordinates": f"{lat},{lon}", "radius": 10000, "limit": 1, "parameter": "pm25"},
             headers={"User-Agent": "climamind-space"})
    j = r.json()
    pm25 = None
    if j.get("results"):
        ms = j["results"][0].get("measurements", [])
        for m in ms:
            if m.get("parameter") == "pm25":
                pm25 = m.get("value")
                break
    return pm25

def fetch_factors(lat, lon):
    wx = fetch_open_meteo(lat, lon)
    cur = wx.get("current", {})
    factors = {
        "temp_c": cur.get("temperature_2m"),
        "rh": cur.get("relative_humidity_2m"),
        "wind_kmh": cur.get("wind_speed_10m"),
        "precip_mm": cur.get("precipitation"),
        "uv": cur.get("uv_index"),
        "pm25": fetch_openaq_pm25(lat, lon)
    }
    return {"factors": factors, "raw": wx}

def drying_index(temp_c, rh, wind_kmh, cloud_frac=None):
    # Rough 0-100 heuristic: heat and wind speed up drying; humidity (and cloud cover, if given) slow it.
    base = (temp_c or 0) * 1.2 + (wind_kmh or 0) * 0.8 - (rh or 0) * 0.9
    if cloud_frac is not None:
        base -= 20 * cloud_frac
    return max(0, min(100, round(base)))

def heat_stress_index(temp_c, rh, wind_kmh):
    # Rough 0-100 heuristic: heat and humidity raise stress, wind relieves it.
    hs = (temp_c or 0) * 1.1 + (rh or 0) * 0.3 - (wind_kmh or 0) * 0.2
    return max(0, min(100, round(hs)))
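
# PROMPT is filled in with str.format(); the doubled braces keep the JSON
# template literal so the model sees the exact schema it must return.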
PROMPT = """You are ClimaMind, a climate reasoning assistant. Use ONLY the observations provided and return STRICT JSON.

Location: {loc} (lat={lat}, lon={lon}), local time: {t_local}
Observations: temp={temp_c}°C, rh={rh}%, wind={wind_kmh} km/h, precip={precip_mm} mm, uv={uv}, pm25={pm25}
Derived: drying_index={d_idx}, heat_stress_index={hs_idx}

Task: Answer the user’s query: "{query}" for the next 24 hours.
Steps:
1) Identify the relevant factors.
2) Reason causally (2–3 steps).
3) Give a concise recommendation with time window(s) and a confidence.
4) Output a short WHY-TRACE (3 bullets).
Return JSON ONLY:
{{
"answer": "...",
"why_trace": ["...", "...", "..."],
"risk_badge": "Low"|"Moderate"|"High"
}}"""

def call_stub(_prompt: str) -> str:
    return json.dumps({
        "answer": "Based on 32°C, 50% RH and 12 km/h wind, cotton dries in ~2–3h (faster after 2pm).",
        "why_trace": [
            "Higher temperature and wind increase evaporation rate",
            "Moderate humidity slightly slows drying",
            "Lower afternoon cloud cover speeds it up"
        ],
        "risk_badge": "Low"
    })

def call_hf_model(prompt: str) -> str:
    # Serverless path: query the hosted model through the HF Inference API.
    from huggingface_hub import InferenceClient
    client = InferenceClient(model=MODEL_ID, token=(HF_TOKEN or None))
    out = client.text_generation(
        prompt,
        max_new_tokens=200,
        temperature=0.1,
        repetition_penalty=1.05,
        do_sample=False,
    )
    return str(out)
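
# Local path: load the model once in-process, 4-bit NF4 quantized via bitsandbytes.
# Assumes a CUDA GPU is available; on CPU-only hardware prefer PROVIDER=hf_model or stub.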
_local_loaded = False

def _ensure_local_loaded():
    global _local_loaded, tokenizer, model
    if _local_loaded:
        return
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    import torch
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        device_map="auto",
        quantization_config=bnb_cfg,
        low_cpu_mem_usage=True,
    )
    _local_loaded = True

def call_local(prompt: str) -> str:
    _ensure_local_loaded()
    import torch
    if hasattr(tokenizer, "apply_chat_template"):
        messages = [{"role": "user", "content": prompt}]
        # return_dict=True yields a mapping that can be unpacked into generate();
        # add_generation_prompt=True appends the assistant turn marker.
        inputs = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device)
    else:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.1,
            do_sample=False,
            repetition_penalty=1.05,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens so the echoed prompt (which itself
    # contains braces) does not confuse the JSON extraction downstream.
    gen_tokens = out[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(gen_tokens, skip_special_tokens=True)

def reason_answer(loc, coords, factors, query):
    d_idx = drying_index(factors.get("temp_c"), factors.get("rh"), factors.get("wind_kmh"))
    hs_idx = heat_stress_index(factors.get("temp_c"), factors.get("rh"), factors.get("wind_kmh"))
    t_local = time.strftime("%Y-%m-%d %H:%M")  # server clock, not the city's timezone
    prompt = PROMPT.format(
        loc=loc, lat=coords["lat"], lon=coords["lon"], t_local=t_local,
        temp_c=factors.get("temp_c"), rh=factors.get("rh"), wind_kmh=factors.get("wind_kmh"),
        precip_mm=factors.get("precip_mm"), uv=factors.get("uv"), pm25=factors.get("pm25"),
        d_idx=d_idx, hs_idx=hs_idx, query=query
    )

    if PROVIDER == "hf_model":
        raw = call_hf_model(prompt)
    elif PROVIDER == "local":
        raw = call_local(prompt)
    else:
        raw = call_stub(prompt)

    # Extract the first {...} block from the model output; fall back to a fixed
    # message when no parseable JSON can be recovered.
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end == -1:
        return {
            "answer": "The reasoning service returned non-JSON text. Please try again.",
            "why_trace": ["Response formatting issue", "A lower sampling temperature helps", "Retry the query"],
            "risk_badge": "Low"
        }
    try:
        return json.loads(raw[start:end + 1])
    except Exception:
        return {
            "answer": "Failed to parse JSON from model output.",
            "why_trace": ["JSON parsing error", "Reduce tokens/temperature", "Retry once"],
            "risk_badge": "Low"
        }
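
# Gradio callback: geocode the city, pull live observations, run the reasoning
# prompt through the selected provider, and render the result as Markdown.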
def app(city, question):
    geo = geocode_city(city)
    data = fetch_factors(geo["lat"], geo["lon"])
    ans = reason_answer(
        geo["name"], {"lat": geo["lat"], "lon": geo["lon"]},
        data["factors"], question
    )
    fx = ", ".join([f"{k}={v}" for k, v in data["factors"].items()])
    why_list = ans.get("why_trace") or []
    why = "\n• " + "\n• ".join(why_list) if why_list else "\n• (no trace returned)"
    md = (
        f"**Answer:** {ans.get('answer', '(no answer)')}\n\n"
        f"**Why-trace:**{why}\n\n"
        f"**Risk:** {ans.get('risk_badge', 'N/A')}\n\n"
        f"**Factors:** {fx}"
    )
    return md

demo = gr.Interface(
    fn=app,
    inputs=[
        gr.Textbox(label="City", value="New Delhi"),
        gr.Dropdown(
            choices=[
                "If I wash clothes now, when will they dry?",
                "Should I water my plants today or wait?",
                "What is the heat/wildfire risk today? Explain briefly."
            ],
            label="Question",
            value="If I wash clothes now, when will they dry?"
        )
    ],
    outputs=gr.Markdown(label="ClimaMind"),
    title="ClimaMind — K2-Think + Live Climate Data",
    description="Provider = hf_model (Inference API) | local (GPU Space) | stub (offline). Configure env in Space settings.",
    allow_flagging="never"
)
# Note: queue(concurrency_count=...) and allow_flagging follow the Gradio 3.x API;
# newer Gradio releases rename these (e.g. default_concurrency_limit, flagging_mode).
demo.queue(concurrency_count=2, max_size=8)

if __name__ == "__main__":
    demo.launch()