AnujithM committed
Commit 1725128 · verified · 1 Parent(s): 8b1992d

Update app.py

Files changed (1)
  1. app.py +58 -38
app.py CHANGED
@@ -1,7 +1,7 @@
 # app.py — ClimaMind on Hugging Face Spaces (Gradio)
-# Modes:
-# PROVIDER=hf_model (default) -> calls HF Inference API for K2 (recommended for demo)
-# PROVIDER=local -> loads model with transformers (requires GPU Space)
+# Providers:
+# PROVIDER=hf_model (default) -> calls HF Inference API (tries MODEL_ID then ALT_MODEL_ID)
+# PROVIDER=local -> loads model in Space (requires GPU)
 # PROVIDER=stub -> offline canned answers
 
 import os, time, json, random
@@ -9,9 +9,10 @@ import requests
 import gradio as gr
 
 # -------- Config --------
-PROVIDER = os.getenv("PROVIDER", "hf_model").strip()
-MODEL_ID = os.getenv("MODEL_ID", "MBZUAI-IFM/K2-Think-SFT").strip()
-HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
+PROVIDER = os.getenv("PROVIDER", "hf_model").strip()
+MODEL_ID = os.getenv("MODEL_ID", "LLM360/K2-Think").strip()  # default = public K2
+ALT_MODEL_ID = os.getenv("ALT_MODEL_ID", "Qwen/Qwen2.5-7B-Instruct").strip()  # fallback that works on serverless
+HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
 
 # -------- HTTP helper --------
 def _get(url, params=None, headers=None, timeout=12, retries=2, backoff=1.6):
@@ -45,21 +46,20 @@ def fetch_open_meteo(lat, lon):
     })
     return r.json()
 
-# -------- PM2.5 (Open-Meteo Air-Quality, free; replaces OpenAQ v3 which now needs a key) --------
+# -------- PM2.5 (Open-Meteo Air-Quality, free) --------
 def fetch_pm25(lat, lon):
     try:
         r = _get("https://air-quality-api.open-meteo.com/v1/air-quality", params={
             "latitude": lat, "longitude": lon, "hourly": "pm2_5", "timezone": "auto"
         }, headers={"User-Agent": "climamind-space"})
         j = r.json()
-        # take the most recent hour
         hourly = j.get("hourly", {})
         values = hourly.get("pm2_5") or []
         if values:
             return values[-1]
     except Exception:
         pass
-    return None  # graceful fallback
+    return None
 
 def fetch_factors(lat, lon):
     wx = fetch_open_meteo(lat, lon)
@@ -85,7 +85,7 @@ def heat_stress_index(temp_c, rh, wind_kmh):
     hs = (temp_c or 0) * 1.1 + (rh or 0) * 0.3 - (wind_kmh or 0) * 0.2
     return max(0, min(100, round(hs)))
 
-# -------- Prompt --------
+# -------- Prompt (escape literal braces in JSON) --------
 PROMPT = """You are ClimaMind, a climate reasoning assistant. Use ONLY the observations provided and return STRICT JSON.
 
 Location: {loc} (lat={lat}, lon={lon}), local time: {t_local}
@@ -118,17 +118,26 @@ def call_stub(_prompt:str)->str:
         "risk_badge": "Low"
     })
 
-def call_hf_model(prompt:str)->str:
+# Try HF Inference (MODEL_ID -> ALT_MODEL_ID), return (text, model_used)
+def call_hf_model(prompt:str) -> tuple[str, str]:
     from huggingface_hub import InferenceClient
-    client = InferenceClient(model=MODEL_ID, token=(HF_TOKEN or None))
-    out = client.text_generation(
-        prompt,
-        max_new_tokens=200,
-        temperature=0.1,
-        repetition_penalty=1.05,
-        do_sample=False,
-    )
-    return str(out)
+    attempts = [m for m in [MODEL_ID, ALT_MODEL_ID] if m]
+    for mid in attempts:
+        try:
+            client = InferenceClient(model=mid, token=(HF_TOKEN or None))
+            out = client.text_generation(
+                prompt,
+                max_new_tokens=200,
+                temperature=0.1,
+                repetition_penalty=1.05,
+                do_sample=False,
+            )
+            return str(out), mid
+        except Exception as e:
+            print(f"[HF_MODEL] Failed on {mid}: {repr(e)}")
+            continue
+    # If all failed, raise so we can stub
+    raise RuntimeError(f"No serverless provider available. Tried: {attempts}")
 
 _local_loaded = False
 def _ensure_local_loaded():
@@ -153,8 +162,9 @@ def _ensure_local_loaded():
     )
     _local_loaded = True
 
-def call_local(prompt:str)->str:
+def call_local(prompt:str)->tuple[str, str]:
     _ensure_local_loaded()
+    import torch  # import here to avoid dependency if not used
     if hasattr(tokenizer, "apply_chat_template"):
         messages = [{"role":"user","content":prompt}]
         inputs = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt").to(model.device)
@@ -169,7 +179,7 @@ def call_local(prompt:str)->str:
         repetition_penalty=1.05,
         eos_token_id=tokenizer.eos_token_id,
     )
-    return tokenizer.decode(out[0], skip_special_tokens=True)
+    return tokenizer.decode(out[0], skip_special_tokens=True), MODEL_ID
 
 def reason_answer(loc, coords, factors, query):
     d_idx = drying_index(factors.get("temp_c"), factors.get("rh"), factors.get("wind_kmh"))
@@ -183,27 +193,35 @@ def reason_answer(loc, coords, factors, query):
     )
 
     if PROVIDER == "hf_model":
-        raw = call_hf_model(prompt)
+        try:
+            raw, model_used = call_hf_model(prompt)
+        except Exception as e:
+            print("[HF_MODEL] Falling back to stub:", repr(e))
+            raw, model_used = call_stub(prompt), "stub"
     elif PROVIDER == "local":
-        raw = call_local(prompt)
+        raw, model_used = call_local(prompt)
     else:
-        raw = call_stub(prompt)
+        raw, model_used = call_stub(prompt), "stub"
 
+    # Extract JSON
     start, end = raw.find("{"), raw.rfind("}")
     if start == -1 or end == -1:
-        return {
+        parsed = {
             "answer": "The reasoning service returned non-JSON text. Please try again.",
-            "why_trace": ["Response formatting issue", "Low temperature helps", "Retry the query"],
-            "risk_badge": "Low"
-        }
-    try:
-        return json.loads(raw[start:end+1])
-    except Exception:
-        return {
-            "answer": "Failed to parse JSON from model output.",
-            "why_trace": ["JSON parsing error", "Reduce tokens/temperature", "Retry once"],
+            "why_trace": ["Response formatting issue", "Keep temperature low", "Retry once"],
             "risk_badge": "Low"
         }
+    else:
+        try:
+            parsed = json.loads(raw[start:end+1])
+        except Exception:
+            parsed = {
+                "answer": "Failed to parse JSON from model output.",
+                "why_trace": ["JSON parsing error", "Reduce tokens/temperature", "Retry once"],
                "risk_badge": "Low"
+            }
+    parsed["_model_used"] = model_used
+    return parsed
 
 # -------- Gradio UI --------
 def app(city, question):
@@ -216,11 +234,13 @@ def app(city, question):
     fx = ", ".join([f"{k}={v}" for k, v in data["factors"].items()])
     why_list = ans.get("why_trace") or []
     why = "\n• " + "\n• ".join(why_list) if why_list else "\n• (no trace returned)"
+    model_used = ans.pop("_model_used", "unknown")
     md = (
         f"**Answer:** {ans.get('answer','(no answer)')}\n\n"
        f"**Why-trace:**{why}\n\n"
         f"**Risk:** {ans.get('risk_badge','N/A')}\n\n"
-        f"**Factors:** {fx}"
+        f"**Factors:** {fx}\n\n"
+        f"<sub>Provider: {PROVIDER} • Model: `{model_used}`</sub>"
     )
     return md
 
@@ -240,8 +260,8 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(label="ClimaMind"),
     title="ClimaMind — K2-Think + Live Climate Data",
-    description="Provider = hf_model (Inference API) | local (GPU Space) | stub (offline). Configure env in Space settings.",
-    allow_flagging="never",
+    description="Serverless tries K2, falls back to Qwen if needed; or run locally on GPU Space. Stub as last resort.",
+    flagging_mode="never",
     concurrency_limit=2,
 )
 