"""Gradio demo in which an LLM plays Wordle against a TextArena environment.

The model is queried through an OpenAI-compatible endpoint, its reply is
reduced to a single bracketed guess, the guess is sent to the TextArena
environment, and the accumulated feedback is rendered on a ``WordleBoard``.
"""

from __future__ import annotations

import asyncio
import os
import re
from typing import AsyncIterator, Dict, List, Optional

import gradio as gr
from gradio_wordleboard import WordleBoard
from openai import AsyncOpenAI

from envs.textarena_env import TextArenaAction, TextArenaEnv
from envs.textarena_env.models import TextArenaMessage

# Runtime configuration, all overridable through environment variables.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
MODEL = os.getenv("MODEL", "openai/gpt-oss-120b:novita")
MAX_TURNS = int(os.getenv("MAX_TURNS", "6"))
DOCKER_IMAGE = os.getenv("TEXTARENA_IMAGE", "textarena-env:latest")
# Base URL of the TextArena environment server; the default is the value that
# was previously hard-coded in ``inference_handler``.
TEXTARENA_BASE_URL = os.getenv(
    "TEXTARENA_BASE_URL", "https://burtenshaw-textarena.hf.space"
)

_SYSTEM_PROMPT = (
    "You are an expert Wordle solver."
    " Always respond with a single guess inside square brackets, e.g. [crane]."
    " Use lowercase letters, exactly one five-letter word per reply."
    " Reason about prior feedback before choosing the next guess."
    " Words must be 5 letters long and real English words."
    " Do not include any other text in your response."
    " Do not repeat the same guess twice."
)

# Guess used whenever nothing usable can be extracted from the model reply.
_DEFAULT_GUESS = "[crane]"
# A five-letter word in square brackets, tolerating whitespace inside them.
_BRACKETED_GUESS_RE = re.compile(r"\[\s*([A-Za-z]{5})\s*\]")
# A run of exactly five letters that is not part of a longer run of letters.
_STANDALONE_WORD_RE = re.compile(r"(?<![A-Za-z])[A-Za-z]{5}(?![A-Za-z])")
# Message categories that ``_collect_feedback`` keeps.
_FEEDBACK_TAGS = frozenset({"FEEDBACK", "SYSTEM", "MESSAGE"})


def _format_history(messages: List[TextArenaMessage]) -> str:
    """Render messages as ``[CATEGORY] content`` lines joined by newlines.

    Messages with a falsy ``category`` are tagged ``MESSAGE``.
    """
    return "\n".join(
        f"[{message.category or 'MESSAGE'}] {message.content}"
        for message in messages
    )


def _make_user_prompt(prompt_text: str, messages: List[TextArenaMessage]) -> str:
    """Build the user-role prompt from the current prompt and message history."""
    history = _format_history(messages)
    return (
        f"Current prompt:\n{prompt_text}\n\n"
        f"Conversation so far:\n{history}\n\n"
        "Reply with your next guess enclosed in square brackets."
    )


async def _generate_guesses(
    client: AsyncOpenAI, prompt: str, history: List[TextArenaMessage]
) -> str:
    """Ask the model for its next guess and return the stripped reply text.

    Returns an empty string when the response carries no choices or no
    content; ``_extract_guess`` maps that to the default guess.
    """
    response = await client.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": _make_user_prompt(prompt, history)},
        ],
        max_tokens=64,
        temperature=0.7,
    )
    # Guard against an empty ``choices`` list instead of raising IndexError.
    content = response.choices[0].message.content if response.choices else None
    response_text = content.strip() if content else ""
    print(f"Response text: {response_text}")
    return response_text


async def _play_wordle(
    env: TextArenaEnv, client: AsyncOpenAI
) -> AsyncIterator[Dict[str, str]]:
    """Play one episode, yielding ``{"guess", "feedback"}`` after every step.

    The blocking ``env.reset`` / ``env.step`` calls run in a worker thread so
    the event loop stays responsive. At most ``MAX_TURNS`` guesses are made.
    """
    state = await asyncio.to_thread(env.reset)
    observation = state.observation

    for _ in range(MAX_TURNS):
        if state.done:
            break

        model_output = await _generate_guesses(
            client, observation.prompt, observation.messages
        )
        guess = _extract_guess(model_output)

        state = await asyncio.to_thread(env.step, TextArenaAction(message=guess))
        observation = state.observation

        yield {"guess": guess, "feedback": _collect_feedback(observation.messages)}

    # Emit the latest observation once more with an empty guess so a consumer
    # receives at least one item even if the episode was already done after
    # reset (in which case the loop above yields nothing).
    yield {
        "guess": "",
        "feedback": _collect_feedback(observation.messages),
    }


def _extract_guess(text: str) -> str:
    """Reduce a raw model reply to a lowercase guess of the form ``[abcde]``.

    Strategies are tried in order:

    1. The first five-letter word enclosed in square brackets.
    2. The last standalone five-letter word in the reply. This is a heuristic
       for replies that ignore the bracket format; it yields a real token from
       the reply rather than a slice across word boundaries.
    3. The first five letters after removing every non-letter, which covers
       spaced-out replies such as ``c r a n e``.
    4. The default guess ``[crane]``.
    """
    if not text:
        return _DEFAULT_GUESS

    bracketed = _BRACKETED_GUESS_RE.search(text)
    if bracketed:
        return f"[{bracketed.group(1).lower()}]"

    standalone_words = _STANDALONE_WORD_RE.findall(text)
    if standalone_words:
        return f"[{standalone_words[-1].lower()}]"

    letters = re.sub(r"[^a-zA-Z]", "", text).lower()
    if len(letters) >= 5:
        return f"[{letters[:5]}]"

    return _DEFAULT_GUESS


def _collect_feedback(messages: List[TextArenaMessage]) -> str:
    """Join the stripped content of feedback-like messages with newlines.

    A message is kept when its category (``MESSAGE`` if falsy), upper-cased,
    is one of ``FEEDBACK``, ``SYSTEM`` or ``MESSAGE``.
    """
    parts = [
        message.content.strip()
        for message in messages
        if (message.category or "MESSAGE").upper() in _FEEDBACK_TAGS
    ]
    return "\n".join(parts).strip()


async def inference_handler(api_key: Optional[str]) -> AsyncIterator[str]:
    """Run one Wordle episode and yield the feedback text after each step.

    Raises:
        RuntimeError: if ``api_key`` is empty or ``None``. Because this is an
            async generator, the error surfaces on the first iteration.
    """
    if not api_key:
        raise RuntimeError("HF_TOKEN or API_KEY environment variable must be set.")

    client = AsyncOpenAI(base_url=API_BASE_URL, api_key=api_key)
    env = TextArenaEnv(base_url=TEXTARENA_BASE_URL)

    try:
        async for result in _play_wordle(env, client):
            yield result["feedback"]
    finally:
        # Always release the environment, even if the consumer stops early.
        env.close()


# Shared component instance used for parsing feedback into board state.
wordle_component = WordleBoard()


async def run_inference() -> AsyncIterator[Dict]:
    """Stream board states to the UI as feedback arrives.

    Every non-empty feedback chunk is appended to a running history; the whole
    history is re-parsed into a game state and yielded in its public dict
    form. If no feedback was produced at all, a fresh empty board is yielded.
    """
    feedback_history: List[str] = []

    async for feedback in inference_handler(API_KEY):
        stripped = feedback.strip()
        if not stripped:
            continue

        feedback_history.append(stripped)
        combined_feedback = "\n\n".join(feedback_history)
        state = wordle_component.parse_feedback(combined_feedback)
        yield wordle_component.to_public_dict(state)

    if not feedback_history:
        yield wordle_component.to_public_dict(wordle_component.create_game_state())


with gr.Blocks() as demo:
    gr.Markdown("# Wordle TextArena Inference Demo")
    board = WordleBoard(
        value=wordle_component.to_public_dict(wordle_component.create_game_state())
    )
    run_button = gr.Button("Run Inference", variant="primary")
    run_button.click(
        fn=run_inference,
        inputs=None,
        outputs=board,
        show_progress=True,
        api_name="run",
    )

demo.queue()

if __name__ == "__main__":
    if not API_KEY:
        raise SystemExit("HF_TOKEN (or API_KEY) must be set to query the model.")
    demo.launch()