burtenshaw (HF Staff) committed · verified
Commit 5dff91b · 1 Parent(s): c785470

Upload app.py with huggingface_hub
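For context, an upload like this is typically produced by a short script along these lines (a sketch; the repo id is a placeholder, not taken from the commit):

from huggingface_hub import upload_file

upload_file(
    path_or_fileobj="app.py",        # local file to push
    path_in_repo="app.py",           # destination path in the repo
    repo_id="<user>/<space-name>",   # placeholder repo id
    repo_type="space",
    commit_message="Upload app.py with huggingface_hub",
)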
Files changed (1): app.py (+339, -0)

app.py ADDED
#!/usr/bin/env python3
"""
Play any Atari game using a Vision-Language Model via the Hugging Face Router API.

The script:
1. Starts an Atari environment (Docker) for the selected game
2. Sends recent screen frames to a vision-language model
3. Parses the model's integer response into an Atari action id
4. Reports a minimal summary

Notes:
- Frames are sent raw (no overlays, cropping, or resizing)
- The model receives the legal action ids each step and must return one integer

Usage:
    export API_KEY=your_hf_token_here
    python app.py  # game, model, and prompt are selected in the Gradio UI
"""

import os
import re
import base64
from collections import deque
from io import BytesIO
from typing import Deque, List, Optional

import gradio as gr
import numpy as np
from PIL import Image
from openai import OpenAI

from envs.atari_env import AtariEnv, AtariAction
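
# envs.atari_env ships with the environment Space, not with this file. Based on
# how it is used below, the client is assumed to expose roughly this interface
# (a sketch of the assumed contract, not the actual implementation):
#
#   env = AtariEnv(base_url="https://<env-space>.hf.space")
#   result = env.reset()                        # result.observation
#   result = env.step(AtariAction(action_id=0))
#   result.observation.screen                   # flattened pixel list
#   result.observation.screen_shape             # e.g. [210, 160, 3]
#   result.observation.legal_actions            # valid action ids this step
#   result.reward, result.done                  # scalar reward, episode end flag
#   env.close()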


# API Configuration
# The Hugging Face Router speaks the OpenAI chat API; API_KEY should hold an HF token
API_BASE_URL = "https://router.huggingface.co/v1"  # Hugging Face Router endpoint
API_KEY = os.getenv("API_KEY")  # Required for Hugging Face
ATARI_ENV_BASE_URL = os.getenv("ATARI_ENV_BASE_URL")  # Optional: connect to a remote Atari env

# Vision-Language Model (Hugging Face Router compatible)
MODEL = "Qwen/Qwen3-VL-8B-Instruct:novita"

# Configuration
TEMPERATURE = 0.7
MAX_STEPS_PER_GAME = 10000
MAX_TOKENS = 16
VERBOSE = True
FRAME_HISTORY_LENGTH = 4
DISPLAY_SCALE = 3  # Scale factor for enlarging frames sent to UI
MODEL_SCALE = 3  # Scale factor for enlarging frames sent to the model

# Generic game prompt for the vision model
VISION_PROMPT = (
    "You are playing an Atari-style game. You will be given recent frames "
    "and the list of legal action ids for the current step. "
    "Respond with a single integer that is exactly one of the legal action ids. "
    "Do not include any words or punctuation — only the integer."
)

ACTIONS_LOOKUP = {
    0: "NOOP",
    1: "FIRE",
    2: "UP",
    3: "RIGHT",
    4: "LEFT",
    5: "DOWN",
    6: "UPRIGHT",
    7: "UPLEFT",
    8: "DOWNRIGHT",
    9: "DOWNLEFT",
    10: "UPFIRE",
    11: "RIGHTFIRE",
    12: "LEFTFIRE",
    13: "DOWNFIRE",
    14: "UPRIGHTFIRE",
    15: "UPLEFTFIRE",
    16: "DOWNRIGHTFIRE",
    17: "DOWNLEFTFIRE",
}
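# Note: this is the full 18-action ALE set; each game exposes only a subset via
# obs.legal_actions (Pong, for example, uses NOOP/FIRE/RIGHT/LEFT and the
# RIGHTFIRE/LEFTFIRE combos).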


def screen_to_base64(screen: List[int], screen_shape: List[int]) -> str:
    """Convert flattened screen array to base64 encoded PNG image (no processing)."""
    screen_array = np.array(screen, dtype=np.uint8).reshape(screen_shape)
    image = Image.fromarray(screen_array)
    # Enlarge image for model input if configured
    try:
        if MODEL_SCALE and MODEL_SCALE > 1:
            image = image.resize((image.width * MODEL_SCALE, image.height * MODEL_SCALE), Image.NEAREST)
    except Exception:
        pass
    buffer = BytesIO()
    image.save(buffer, format='PNG')
    buffer.seek(0)
    return base64.b64encode(buffer.read()).decode('utf-8')


def screen_to_numpy(screen: List[int], screen_shape: List[int]) -> np.ndarray:
    """Convert flattened screen to a larger RGB numpy array for gr.Image display."""
    arr = np.array(screen, dtype=np.uint8).reshape(screen_shape)
    if len(screen_shape) == 3:
        img = Image.fromarray(arr, mode='RGB')
    else:
        img = Image.fromarray(arr, mode='L')
    # Enlarge with nearest-neighbor to preserve pixel edges
    try:
        img = img.resize((img.width * DISPLAY_SCALE, img.height * DISPLAY_SCALE), Image.NEAREST)
    except Exception:
        pass
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return np.array(img)


def content_text(text: str) -> dict:
    return {"type": "text", "text": text}


def content_image_b64(b64_png: str) -> dict:
    return {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_png}"}}


def build_messages(prompt: str, frame_history_b64: Deque[str], current_b64: str, legal_actions: List[int]) -> List[dict]:
    messages: List[dict] = [
        {"role": "system", "content": [content_text(prompt)]}
    ]
    if len(frame_history_b64) > 1:
        total = len(frame_history_b64)
        messages.extend([
            {
                "role": "user",
                "content": [
                    content_text(f"Frame -{total - idx}"),
                    content_image_b64(_img),
                ],
            }
            for idx, _img in enumerate(list(frame_history_b64)[:-1])
        ])
    messages.append({
        "role": "user",
        "content": [content_text("Current frame:"), content_image_b64(current_b64)],
    })
    # Include mapping of action ids to human-readable names for the model
    action_pairs = ", ".join([f"{aid}:{ACTIONS_LOOKUP.get(aid, 'UNK')}" for aid in legal_actions])
    messages.append({
        "role": "user",
        "content": [content_text(f"Legal actions (id:name): {action_pairs}. Respond with exactly one INTEGER id.")],
    })
    return messages
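
# For reference, once the history is full (FRAME_HISTORY_LENGTH = 4) a single
# request body looks roughly like:
#   [{"role": "system", "content": [{"type": "text", "text": "<prompt>"}]},
#    {"role": "user", "content": ["Frame -4" text, image_url data-URL]},
#    {"role": "user", "content": ["Frame -3" text, image_url data-URL]},
#    {"role": "user", "content": ["Frame -2" text, image_url data-URL]},
#    {"role": "user", "content": ["Current frame:" text, image_url data-URL]},
#    {"role": "user", "content": ["Legal actions (id:name): 0:NOOP, ..." text]}]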


class GameSession:
    """Holds environment/model state and advances one step per tick."""

    def __init__(self, game: str, model_name: str, prompt_text: str):
        if not API_KEY:
            raise RuntimeError("Missing API_KEY for HF Router")
        self.client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        self.env: Optional[AtariEnv] = None
        self.model_name = model_name
        self.game = game
        self.prompt = (prompt_text or "").strip() or VISION_PROMPT
        self.frame_history_base64: Deque[str] = deque(maxlen=FRAME_HISTORY_LENGTH)
        self.total_reward = 0.0
        self.steps = 0
        self.done = False

        # Start environment (honor the ATARI_ENV_BASE_URL override if set)
        self.env = AtariEnv(base_url=ATARI_ENV_BASE_URL or f"https://burtenshaw-{game}.hf.space")
        result = self.env.reset()
        self.obs = result.observation
        self.log_message = f"Game: {self.game} started"

    def close(self):
        if self.env is not None:
            try:
                self.env.close()
            finally:
                self.env = None
        self.done = True

    def next_frame(self) -> Optional[np.ndarray]:
        # Snapshot env reference to avoid race if another thread closes it mid-tick
        env = self.env
        if self.done or env is None:
            return None
        if self.steps >= MAX_STEPS_PER_GAME:
            self.close()
            return None

        # Prepare images
        image_data = screen_to_base64(self.obs.screen, self.obs.screen_shape)
        if FRAME_HISTORY_LENGTH > 0:
            self.frame_history_base64.append(image_data)

        # Build messages (deduplicated helpers)
        messages = build_messages(self.prompt, self.frame_history_base64, image_data, self.obs.legal_actions)

        # Query model
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
            )
            response_text = completion.choices[0].message.content or ""
            action_id = parse_action(response_text, self.obs.legal_actions)
        except Exception:
            action_id = 0 if 0 in self.obs.legal_actions else self.obs.legal_actions[0]

        # Step env (guard against races with stop/close)
        try:
            result = env.step(AtariAction(action_id=action_id))
        except AttributeError:
            # env likely closed concurrently
            self.close()
            return None
        except Exception:
            # Network/server error - stop session gracefully
            self.close()
            return None
        self.obs = result.observation
        self.total_reward += result.reward or 0.0
        self.steps += 1
        if result.done:
            self.done = True
            self.close()

        action_name = ACTIONS_LOOKUP.get(action_id, str(action_id))
        self.log_message += f"\nAction: {action_name} ({action_id}) Reward: {result.reward}"
        return screen_to_numpy(self.obs.screen, self.obs.screen_shape)


def parse_action(text: str, legal_actions: List[int]) -> int:
    """
    Parse action from model output.
    Handles chain-of-thought format by taking the LAST valid number found.

    Args:
        text: Model's text response (may include reasoning)
        legal_actions: List of valid action IDs

    Returns:
        Selected action ID (defaults to NOOP if parsing fails)
    """
    # Look for integer tokens in the response
    numbers = re.findall(r'\b\d+\b', text)

    # Check from the end (last number is likely the final action after reasoning)
    for num_str in reversed(numbers):
        action_id = int(num_str)
        if action_id in legal_actions:
            return action_id

    # Default to NOOP if available, otherwise first legal action
    return 0 if 0 in legal_actions else legal_actions[0]
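
# A few examples of the parsing behavior (with legal_actions = [0, 1, 3, 4]):
#   parse_action("3", ...)                         -> 3
#   parse_action("The ball is left, so 4.", ...)   -> 4
#   parse_action("Maybe 9... final answer 1", ...) -> 1  (last legal number wins)
#   parse_action("no idea", ...)                   -> 0  (falls back to NOOP)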


# Legacy CLI loop removed; Gradio's Image.every drives stepping via GameSession.next_frame
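# For a quick headless sanity check, a session can still be driven directly
# (a sketch; assumes API_KEY is set and the pong Space is reachable):
#   session = GameSession(game="pong", model_name=MODEL, prompt_text=VISION_PROMPT)
#   while session.next_frame() is not None:
#       pass
#   print(session.log_message)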


def start_session(game: str, model_name: str, prompt_text: str) -> Optional[GameSession]:
    try:
        return GameSession(game=game, model_name=model_name, prompt_text=prompt_text)
    except Exception as e:
        raise gr.Error(str(e))


def stop_session(session: Optional[GameSession]) -> Optional[GameSession]:
    if isinstance(session, GameSession):
        session.close()
    return None


def frame_tick(session: Optional[GameSession]) -> Optional[np.ndarray]:
    if not isinstance(session, GameSession):
        return None
    frame = session.next_frame()
    if frame is None:
        # Auto-stop when done
        session.close()
        return None
    return frame


def log_tick(session: Optional[GameSession]) -> str:
    if not isinstance(session, GameSession):
        return ""
    return session.log_message


def launch_gradio_app():
    games = [
        "pong",
        "breakout",
        "pacman",
    ]
    models = [
        "Qwen/Qwen3-VL-8B-Instruct",
        "Qwen/Qwen3-VL-72B-A14B-Instruct",
        "Qwen/Qwen3-VL-235B-A22B-Instruct",
    ]

    with gr.Blocks() as demo:
        gr.Markdown("""
        ### Atari Vision-Language Control
        - Select a game and model, edit the prompt, then click Start.
        - Frames are streamed directly from the environment without modification.
        - A limited number of environment Spaces are available at `https://burtenshaw-{game}.hf.space`.
        - Duplicate the Space and change its environment variables if you want to use a different game.
        """)

        session_state = gr.State()

        with gr.Row():
            with gr.Column():
                game_dd = gr.Dropdown(choices=games, value="pong", label="Game")
                model_dd = gr.Dropdown(choices=models, value=models[0], label="Model")
                prompt_tb = gr.Textbox(label="Prompt", value=VISION_PROMPT, lines=6)
                with gr.Row():
                    start_btn = gr.Button("Start", variant="primary")
                    stop_btn = gr.Button("Stop")

            with gr.Column():
                out_image = gr.Image(label="Game Stream", type="numpy", value=frame_tick, inputs=[session_state], every=0.1, height=480, width=640)
                out_text = gr.Textbox(label="Game Logs", value=log_tick, inputs=[session_state], lines=10, every=0.5)

        # Controls
        start_btn.click(start_session, inputs=[game_dd, model_dd, prompt_tb], outputs=[session_state])
        stop_btn.click(stop_session, inputs=[session_state], outputs=[session_state])

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    launch_gradio_app()