from __future__ import annotations

import asyncio
import os
import re
from typing import AsyncIterator, Dict, List

import gradio as gr
from gradio_wordleboard import WordleBoard
from openai import AsyncOpenAI

from envs.textarena_env import TextArenaAction, TextArenaEnv
from envs.textarena_env.models import TextArenaMessage

# Configuration comes from environment variables so the Space can be pointed at a
# different router, model, turn limit, or TextArena image without code changes.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
MODEL = os.getenv("MODEL", "openai/gpt-oss-120b:novita")
MAX_TURNS = int(os.getenv("MAX_TURNS", "6"))
DOCKER_IMAGE = os.getenv("TEXTARENA_IMAGE", "textarena-env:latest")

def _format_history(messages: List[TextArenaMessage]) -> str:
    """Render the environment messages as one tagged line per message."""
    lines: List[str] = []
    for message in messages:
        tag = message.category or "MESSAGE"
        lines.append(f"[{tag}] {message.content}")
    return "\n".join(lines)


def _make_user_prompt(prompt_text: str, messages: List[TextArenaMessage]) -> str:
    """Combine the current environment prompt with the conversation so far."""
    history = _format_history(messages)
    return (
        f"Current prompt:\n{prompt_text}\n\n"
        f"Conversation so far:\n{history}\n\n"
        "Reply with your next guess enclosed in square brackets."
    )

async def _generate_guesses(client: AsyncOpenAI, prompt: str, history: List[TextArenaMessage]) -> str:
    """Ask the model for the next guess and return its raw response text."""
    response = await client.chat.completions.create(
        model=MODEL,
        messages=[
            {
                "role": "system",
                "content": (
                    "You are an expert Wordle solver."
                    " Always respond with a single guess inside square brackets, e.g. [crane]."
                    " Use lowercase letters, exactly one five-letter word per reply."
                    " Reason about prior feedback before choosing the next guess."
                    " Words must be 5 letters long and real English words."
                    " Do not include any other text in your response."
                    " Do not repeat the same guess twice."
                ),
            },
            {"role": "user", "content": _make_user_prompt(prompt, history)},
        ],
        max_tokens=64,
        temperature=0.7,
    )
    content = response.choices[0].message.content
    response_text = content.strip() if content else ""
    print(f"Response text: {response_text}")
    return response_text

async def _play_wordle(env: TextArenaEnv, client: AsyncOpenAI) -> AsyncIterator[Dict[str, str]]:
    """Play one Wordle episode, yielding the guess and feedback after every turn."""
    state = await asyncio.to_thread(env.reset)
    observation = state.observation
    for turn in range(1, MAX_TURNS + 1):
        if state.done:
            break
        model_output = await _generate_guesses(client, observation.prompt, observation.messages)
        guess = _extract_guess(model_output)
        # The environment client is synchronous, so run the step in a worker thread.
        state = await asyncio.to_thread(env.step, TextArenaAction(message=guess))
        observation = state.observation
        feedback = _collect_feedback(observation.messages)
        yield {"guess": guess, "feedback": feedback}
    # Emit the final feedback once the game ends or the turn limit is reached.
    yield {
        "guess": "",
        "feedback": _collect_feedback(observation.messages),
    }

def _extract_guess(text: str) -> str:
    """Pull a bracketed five-letter guess out of the model output, with fallbacks."""
    if not text:
        return "[crane]"
    match = re.search(r"\[([A-Za-z]{5})\]", text)
    if match:
        guess = match.group(1).lower()
        return f"[{guess}]"
    # No bracketed guess found: strip non-letters and take the first five characters.
    cleaned = re.sub(r"[^a-zA-Z]", "", text).lower()
    if len(cleaned) >= 5:
        return f"[{cleaned[:5]}]"
    # Last resort: fall back to a common opening word.
    return "[crane]"

def _collect_feedback(messages: List[TextArenaMessage]) -> str:
    parts: List[str] = []
    for message in messages:
        tag = message.category or "MESSAGE"
        if tag.upper() in {"FEEDBACK", "SYSTEM", "MESSAGE"}:
            parts.append(message.content.strip())
    return "\n".join(parts).strip()

async def inference_handler(api_key: str) -> AsyncIterator[str]:
    """Stream feedback strings from a full Wordle game played against the model."""
    if not api_key:
        raise RuntimeError("HF_TOKEN or API_KEY environment variable must be set.")
    client = AsyncOpenAI(base_url=API_BASE_URL, api_key=api_key)
    env = TextArenaEnv(base_url="https://burtenshaw-textarena.hf.space")
    try:
        async for result in _play_wordle(env, client):
            yield result["feedback"]
    finally:
        env.close()

wordle_component = WordleBoard()


async def run_inference() -> AsyncIterator[Dict]:
    """Drive the game and stream WordleBoard states to the Gradio UI."""
    feedback_history: List[str] = []
    async for feedback in inference_handler(API_KEY):
        stripped = feedback.strip()
        if not stripped:
            continue
        feedback_history.append(stripped)
        # Re-parse the accumulated feedback so the board reflects every turn so far.
        combined_feedback = "\n\n".join(feedback_history)
        state = wordle_component.parse_feedback(combined_feedback)
        yield wordle_component.to_public_dict(state)
    if not feedback_history:
        # Nothing came back from the environment; show an empty board.
        yield wordle_component.to_public_dict(wordle_component.create_game_state())

with gr.Blocks() as demo:
    gr.Markdown("# Wordle TextArena Inference Demo")
    board = WordleBoard(value=wordle_component.to_public_dict(wordle_component.create_game_state()))
    run_button = gr.Button("Run Inference", variant="primary")
    run_button.click(
        fn=run_inference,
        inputs=None,
        outputs=board,
        show_progress=True,
        api_name="run",
    )

demo.queue()

if __name__ == "__main__":
    if not API_KEY:
        raise SystemExit("HF_TOKEN (or API_KEY) must be set to query the model.")
    demo.launch()
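
# Usage note (a minimal sketch, not part of the app): because the click handler is
# registered with api_name="run", the running Space exposes a "/run" generator
# endpoint that can be driven remotely with gradio_client. The Space id below is a
# placeholder; substitute the actual deployment.
#
#     import time
#     from gradio_client import Client
#
#     client = Client("<owner>/<space-name>")   # placeholder Space id
#     job = client.submit(api_name="/run")      # the streaming endpoint defined above
#     while not job.done():
#         time.sleep(0.5)
#     for board_state in job.outputs():         # every board state streamed during the game
#         print(board_state)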