#!/usr/bin/env python3
"""
Play any Atari game using a Vision-Language Model via the Hugging Face Router API.
The script:
1. Starts an Atari environment (Docker) for the selected game
2. Sends recent screen frames to a vision-language model
3. Parses the model's integer response into an Atari action id
4. Reports a minimal summary
Notes:
- Frames are sent raw (no overlays, cropping, or resizing)
- The model receives the legal action ids each step and must return one integer
Usage:
    export API_KEY=your_hf_token_here
    python examples/atari_pong_inference.py --game breakout --model Qwen/Qwen3-VL-8B-Instruct:novita
"""
import os
import re
import base64
from collections import deque
from io import BytesIO
from typing import Deque, List, Optional
import gradio as gr
import numpy as np
from PIL import Image
from openai import OpenAI
from envs.atari_env import AtariEnv, AtariAction
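# Note: `envs.atari_env` is the project's local client. As used below it is
# assumed to expose AtariEnv(base_url=...) with reset()/step()/close(), and
# observations carrying `screen`, `screen_shape`, and `legal_actions`.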
# API Configuration
# Set API_KEY to a Hugging Face token with access to the Router API
API_BASE_URL = "https://router.huggingface.co/v1"  # Hugging Face Router endpoint
API_KEY = os.getenv("API_KEY")  # Required for Hugging Face
ATARI_ENV_BASE_URL = os.getenv("ATARI_ENV_BASE_URL")  # Optional: override the default per-game Space URL
# Vision-Language Model (Hugging Face Router compatible)
MODEL = "Qwen/Qwen3-VL-8B-Instruct:novita"
# Configuration
TEMPERATURE = 0.7
MAX_STEPS_PER_GAME = 10000
MAX_TOKENS = 16
VERBOSE = True
FRAME_HISTORY_LENGTH = 4
DISPLAY_SCALE = 3  # Scale factor for enlarging frames sent to UI
MODEL_SCALE = 3    # Scale factor for enlarging frames sent to the model
# Generic game prompt for the vision model
VISION_PROMPT = (
    "You are playing an Atari-style game. You will be given recent frames "
    "and the list of legal action ids for the current step. "
    "Respond with a single integer that is exactly one of the legal action ids. "
    "Do not include any words or punctuation — only the integer."
)
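# Full ALE action set (ids 0-17). Each game exposes only a subset, reported
# per step via `obs.legal_actions`, so the model is only ever offered legal ids.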
ACTIONS_LOOKUP = {
    0: "NOOP",
    1: "FIRE",
    2: "UP",
    3: "RIGHT",
    4: "LEFT",
    5: "DOWN",
    6: "UPRIGHT",
    7: "UPLEFT",
    8: "DOWNRIGHT",
    9: "DOWNLEFT",
    10: "UPFIRE",
    11: "RIGHTFIRE",
    12: "LEFTFIRE",
    13: "DOWNFIRE",
    14: "UPRIGHTFIRE",
    15: "UPLEFTFIRE",
    16: "DOWNRIGHTFIRE",
    17: "DOWNLEFTFIRE",
}
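# Screens arrive as a flat integer list plus a shape (typically [210, 160, 3]
# RGB for ALE games); the helpers below rebuild the image and enlarge it with
# nearest-neighbor scaling so small sprites stay legible.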
def screen_to_base64(screen: List[int], screen_shape: List[int]) -> str:
    """Convert flattened screen array to base64 encoded PNG image (no processing)."""
    screen_array = np.array(screen, dtype=np.uint8).reshape(screen_shape)
    image = Image.fromarray(screen_array)
    # Enlarge image for model input if configured
    try:
        if MODEL_SCALE and MODEL_SCALE > 1:
            image = image.resize((image.width * MODEL_SCALE, image.height * MODEL_SCALE), Image.NEAREST)
    except Exception:
        pass
    buffer = BytesIO()
    image.save(buffer, format='PNG')
    buffer.seek(0)
    return base64.b64encode(buffer.read()).decode('utf-8')
def screen_to_numpy(screen: List[int], screen_shape: List[int]) -> np.ndarray:
    """Convert flattened screen to a larger RGB numpy array for gr.Image display."""
    arr = np.array(screen, dtype=np.uint8).reshape(screen_shape)
    if len(screen_shape) == 3:
        img = Image.fromarray(arr, mode='RGB')
    else:
        img = Image.fromarray(arr, mode='L')
    # Enlarge with nearest-neighbor to preserve pixel edges
    try:
        img = img.resize((img.width * DISPLAY_SCALE, img.height * DISPLAY_SCALE), Image.NEAREST)
    except Exception:
        pass
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return np.array(img)
def content_text(text: str) -> dict:
    return {"type": "text", "text": text}
def content_image_b64(b64_png: str) -> dict:
    return {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_png}"}}
def build_messages(prompt: str, frame_history_b64: Deque[str], current_b64: str, legal_actions: List[int]) -> List[dict]:
    messages: List[dict] = [
        {"role": "system", "content": [content_text(prompt)]}
    ]
    if len(frame_history_b64) > 1:
        total = len(frame_history_b64)
        messages.extend([
            {
                "role": "user",
                "content": [
                    content_text(f"Frame -{total - idx}"),
                    content_image_b64(_img),
                ],
            }
            for idx, _img in enumerate(list(frame_history_b64)[:-1])
        ])
    messages.append({
        "role": "user",
        "content": [content_text("Current frame:"), content_image_b64(current_b64)],
    })
    # Include mapping of action ids to human-readable names for the model
    action_pairs = ", ".join([f"{aid}:{ACTIONS_LOOKUP.get(aid, 'UNK')}" for aid in legal_actions])
    messages.append({
        "role": "user",
        "content": [content_text(f"Legal actions (id:name): {action_pairs}. Respond with exactly one INTEGER id.")],
    })
    return messages
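# For reference, the payload produced above looks roughly like:
#   [{"role": "system", "content": [text(prompt)]},
#    {"role": "user",   "content": [text("Frame -N"), image(older frame)]},   # one per history frame
#    {"role": "user",   "content": [text("Current frame:"), image(current)]},
#    {"role": "user",   "content": [text("Legal actions (id:name): 0:NOOP, 1:FIRE, ... Respond with exactly one INTEGER id.")]}]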
class GameSession:
    """Holds environment/model state and advances one step per tick."""
    def __init__(self, game: str, model_name: str, prompt_text: str):
        if not API_KEY:
            raise RuntimeError("Missing API_KEY for HF Router")
        self.client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        self.env: Optional[AtariEnv] = None
        self.model_name = model_name
        self.game = game
        self.prompt = (prompt_text or "").strip() or VISION_PROMPT
        self.frame_history_base64: Deque[str] = deque(maxlen=FRAME_HISTORY_LENGTH)
        self.total_reward = 0.0
        self.steps = 0
        self.done = False
        # Start environment (use ATARI_ENV_BASE_URL if provided, else the hosted Space for the game)
        self.env = AtariEnv(base_url=ATARI_ENV_BASE_URL or f"https://burtenshaw-{game}.hf.space")
        result = self.env.reset()
        self.obs = result.observation
        self.log_message = f"Game: {self.game} started"
    def close(self):
        if self.env is not None:
            try:
                self.env.close()
            finally:
                self.env = None
        self.done = True
    def next_frame(self) -> Optional[np.ndarray]:
        # Snapshot env reference to avoid race if another thread closes it mid-tick
        env = self.env
        if self.done or env is None:
            return None
        if self.steps >= MAX_STEPS_PER_GAME:
            self.close()
            return None
        # Prepare images
        image_data = screen_to_base64(self.obs.screen, self.obs.screen_shape)
        if FRAME_HISTORY_LENGTH > 0:
            self.frame_history_base64.append(image_data)
        # Build messages (deduplicated helpers)
        messages = build_messages(self.prompt, self.frame_history_base64, image_data, self.obs.legal_actions)
        # Query model
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
            )
            response_text = completion.choices[0].message.content or ""
            action_id = parse_action(response_text, self.obs.legal_actions)
        except Exception:
            action_id = 0 if 0 in self.obs.legal_actions else self.obs.legal_actions[0]
        # Step env (guard against races with stop/close)
        try:
            result = env.step(AtariAction(action_id=action_id))
        except AttributeError:
            # env likely closed concurrently
            self.close()
            return None
        except Exception:
            # Network/server error - stop session gracefully
            self.close()
            return None
        self.obs = result.observation
        self.total_reward += result.reward or 0.0
        self.steps += 1
        if result.done:
            self.done = True
            self.close()
        
        action_name = ACTIONS_LOOKUP.get(action_id, str(action_id))
        self.log_message += f"\nAction: {action_name} ({action_id}) Reward: {result.reward}"
        return screen_to_numpy(self.obs.screen, self.obs.screen_shape)
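# A GameSession is created by start_session(), polled once per UI tick via
# next_frame() (driven by the gr.Image `every` callback below), and torn down
# by stop_session() or automatically when the episode ends or MAX_STEPS_PER_GAME is hit.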
def parse_action(text: str, legal_actions: List[int]) -> int:
    """
    Parse action from model output.
    Handles chain-of-thought format by taking the LAST valid number found.
    
    Args:
        text: Model's text response (may include reasoning)
        legal_actions: List of valid action IDs
    
    Returns:
        Selected action ID (defaults to NOOP if parsing fails)
    """
    # Find all standalone integers in the response
    numbers = re.findall(r'\b\d+\b', text)
    
    # Check from the end (last number is likely the final action after reasoning)
    for num_str in reversed(numbers):
        action_id = int(num_str)
        if action_id in legal_actions:
            return action_id
    
    # Default to NOOP if available, otherwise first legal action
    return 0 if 0 in legal_actions else legal_actions[0]
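# Examples of the parsing behaviour, assuming legal_actions == [0, 1, 3, 4]:
#   parse_action("I will go with action 3", [0, 1, 3, 4]) -> 3
#   parse_action("maybe 7, so let's press 1", [0, 1, 3, 4]) -> 1   # last legal number wins
#   parse_action("no idea", [0, 1, 3, 4]) -> 0                     # defaults to NOOP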
# Legacy CLI loop removed; Gradio's Image.every drives stepping via GameSession.next_frame
def start_session(game: str, model_name: str, prompt_text: str) -> Optional[GameSession]:
    try:
        return GameSession(game=game, model_name=model_name, prompt_text=prompt_text)
    except Exception as e:
        raise gr.Error(str(e))
def stop_session(session: Optional[GameSession]) -> Optional[GameSession]:
    if isinstance(session, GameSession):
        session.close()
    return None
def frame_tick(session: Optional[GameSession]) -> Optional[np.ndarray]:
    if not isinstance(session, GameSession):
        return None
    frame = session.next_frame()
    if frame is None:
        # Auto-stop when done
        session.close()
        return None
    return frame
def log_tick(session: Optional[GameSession]) -> str:
    if not isinstance(session, GameSession):
        return ""
    return session.log_message
def launch_gradio_app():
    games = [
        "pong",
        "breakout",
        "pacman",
    ]
    models = [
        "Qwen/Qwen3-VL-8B-Instruct",
        "Qwen/Qwen3-VL-72B-A14B-Instruct",
        "Qwen/Qwen3-VL-235B-A22B-Instruct",
    ]
    with gr.Blocks() as demo:
        gr.Markdown("""
        ### Atari Vision-Language Control
        - Select a game and model, edit the prompt, then click Start.
        - Frames are streamed directly from the environment without modification.
        - There are a limited number of environment spaces via `"https://burtenshaw-{game}.hf.space"`
        - Duplicate the space and change environment variables if you want to use a different game.
        """)
        session_state = gr.State()
        
        with gr.Row():
        
            with gr.Column():
                game_dd = gr.Dropdown(choices=games, value="pong", label="Game")
                model_dd = gr.Dropdown(choices=models, value=models[0], label="Model")
                prompt_tb = gr.Textbox(label="Prompt", value=VISION_PROMPT, lines=6)
                with gr.Row():
                    start_btn = gr.Button("Start", variant="primary")
                    stop_btn = gr.Button("Stop")
            with gr.Column():
                out_image = gr.Image(label="Game Stream", type="numpy", value=frame_tick, inputs=[session_state], every=0.1, height=480, width=640)
        
        out_text = gr.Textbox(label="Game Logs", value=log_tick, inputs=[session_state], lines=10, every=0.5)
        
        # Controls
        start_btn.click(start_session, inputs=[game_dd, model_dd, prompt_tb], outputs=[session_state])
        stop_btn.click(stop_session, inputs=[session_state], outputs=[session_state])
    demo.queue()
    demo.launch()
if __name__ == "__main__":
    launch_gradio_app()