Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| import streamlit as st | |
| import torch | |
| from dotenv import load_dotenv | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from hangman import guess_letter | |
| from hf_utils import query_hint, query_word | |
# Name of a YAML configuration file; not referenced by the code visible here.
CONFIGS_PATH = "configs.yaml"

# Number of missed letters after which the game is reported as lost.
MAX_TRIES = 6

# Options offered in the "Choose a category" select box.
CATEGORIES = ["Country", "Animal", "Food", "Movie"]

# Runtime settings: the model to load, the device it is moved to, and the
# generation settings forwarded to query_word / query_hint.
configs = dict(
    os_model="google/gemma-2b-it",
    device="cpu",
    generation_config=dict(
        max_output_tokens=128,
        temperature=1,
        top_p=1,
        top_k=4,
    ),
)
def setup(model_id: str, device: str) -> dict:
    """Initializes the model and tokenizer.

    Args:
        model_id (str): Model ID used to load the tokenizer and model.
        device (str): Torch device string (e.g. "cpu") the model is moved to.

    Returns:
        dict: The loaded assets under the keys "tokenizer" and "model".

    Raises:
        KeyError: If the HF_ACCESS_TOKEN environment variable is not set.
    """
    logger.info("Loading model and tokenizer from model: '%s'", model_id)

    # Read the credential once so both downloads use the same value.
    access_token = os.environ["HF_ACCESS_TOKEN"]

    # Half precision is kept for accelerators only: float16 inference on CPU
    # is poorly supported by PyTorch (missing or slow kernels, overflow while
    # sampling), so fall back to full precision there.
    if torch.device(device).type == "cpu":
        dtype = torch.float32
    else:
        dtype = torch.float16

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=access_token,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        token=access_token,
    ).to(device)

    logger.info("Setup finished")
    return {"tokenizer": tokenizer, "model": model}
# Configure logging first so the messages emitted by setup() are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__file__)

# Page metadata (browser tab title and icon).
st.set_page_config(
    page_title="Gemma Hangman",
    page_icon="🧩",
)

# Pull variables from a .env file into the environment before setup() reads
# HF_ACCESS_TOKEN from os.environ.
load_dotenv()

# Streamlit re-executes this script on every widget interaction. Wrapping
# setup() with st.cache_resource keeps one loaded model/tokenizer per
# (model_id, device) instead of reloading the weights on each rerun.
_load_assets = st.cache_resource(setup)
assets = _load_assets(configs["os_model"], configs["device"])
tokenizer = assets["tokenizer"]
model = assets["model"]
# Seed every game-state key that is still missing. Checking each key (rather
# than only "is the whole session state empty?") guarantees the lookups
# further down never hit a KeyError when some other key is already present.
# The dict literal is rebuilt on every script run, so the list defaults are
# fresh objects each time.
_initial_state = {
    "word": "",
    "hint": "",
    "hangman": "",
    "missed_letters": [],
    "correct_letters": [],
}
for _state_key, _state_default in _initial_state.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# The heading matches the page title and the Gemma model used by the app
# (it previously read "Gemini Hangman").
st.title("Gemma Hangman")
st.markdown("## Guess the word based on a hint")

# Left column: category picker. Right column: start / reset buttons.
col1, col2 = st.columns(2)
with col1:
    category = st.selectbox(
        "Choose a category",
        CATEGORIES,
    )
with col2:
    start_btn = st.button("Start game")
    reset_btn = st.button("Reset game")

if start_btn:
    # Ask the model for a word in the chosen category, then for a hint
    # describing that word.
    new_word = query_word(
        category,
        model,
        tokenizer,
        configs["generation_config"],
        configs["device"],
    )
    st.session_state["word"] = new_word
    st.session_state["hint"] = query_hint(
        new_word,
        model,
        tokenizer,
        configs["generation_config"],
        configs["device"],
    )
    # One underscore placeholder per character of the word.
    st.session_state["hangman"] = "_" * len(new_word)
    st.session_state["missed_letters"] = []
    st.session_state["correct_letters"] = []

if reset_btn:
    # Return every game-state key to its empty value.
    for _reset_key, _reset_value in {
        "word": "",
        "hint": "",
        "hangman": "",
        "missed_letters": [],
        "correct_letters": [],
    }.items():
        st.session_state[_reset_key] = _reset_value
st.markdown(
    """
Note: you must input whitespaces and special characters.
"""
)
st.markdown(f'### Hint:\n{st.session_state["hint"]}')

# Left column: guess input. Right column: current board and missed letters.
col3, col4 = st.columns(2)
with col3:
    guess = st.text_input(label="Enter letter")
    guess_btn = st.button("Guess letter")
    if guess_btn:
        # Never rebind st.session_state itself: the attribute lives on the
        # streamlit module, so replacing it with another object would outlive
        # this run. Merge the result into the existing state instead.
        # NOTE(review): guess_letter is defined elsewhere; it is assumed to
        # return the updated state mapping (possibly the same object it was
        # given) - confirm against hangman.py.
        updated_state = guess_letter(guess, st.session_state)
        if updated_state is not None and updated_state is not st.session_state:
            for _key in updated_state:
                st.session_state[_key] = updated_state[_key]
with col4:
    st.text_input(
        label="Hangman",
        value=st.session_state["hangman"],
    )
    st.text_input(
        label=f"Missed letters (max {MAX_TRIES} tries)",
        value=", ".join(st.session_state["missed_letters"]),
    )

secret_word = st.session_state["word"]
board = st.session_state["hangman"]

# Won: a game is in progress (non-empty board) and every letter is revealed.
if board != "" and secret_word == board:
    st.success("You won!")
    st.balloons()

# Lost: the number of missed letters reached the allowed maximum.
if len(st.session_state["missed_letters"]) >= MAX_TRIES:
    st.error(f"""You lost, the correct word was '{secret_word}'""")
    st.snow()