"""Gradio interface for nanochat model.""" from __future__ import annotations import os from collections.abc import Generator from pathlib import Path from typing import Any import gradio as gr from huggingface_hub import snapshot_download from model import NanochatModel MODEL_REPO = os.environ.get("MODEL_REPO", "sdobson/nanochat") MODEL_DIR = os.environ.get("MODEL_DIR", "./model_cache") _model: NanochatModel | None = None def download_model() -> None: """Download the model from Hugging Face if needed.""" model_path = Path(MODEL_DIR) if not model_path.exists() or not any(model_path.iterdir()): snapshot_download( repo_id=MODEL_REPO, local_dir=MODEL_DIR, ) def load_model() -> None: """Load the nanochat model.""" global _model if _model is None: download_model() _model = NanochatModel(model_dir=MODEL_DIR, device="cpu") load_model() def respond( message: str, history: list[dict[str, str]], temperature: float, top_k: int, ) -> Generator[str, Any, None]: """Generate a response using the nanochat model. Args: message: User's input message history: Chat history in Gradio messages format temperature: Sampling temperature top_k: Top-k sampling parameter Yields: Incrementally generated response text """ conversation = [] for msg in history: conversation.append(msg) conversation.append({"role": "user", "content": message}) response = "" for token in _model.generate( history=conversation, max_tokens=512, temperature=temperature, top_k=top_k, ): response += token yield response chatbot = gr.ChatInterface( respond, type="messages", additional_inputs=[ gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"), gr.Slider( minimum=1, maximum=200, value=50, step=1, label="Top-k sampling", ), ], ) with gr.Blocks(title="nanochat") as demo: gr.Markdown("# nanochat") gr.Markdown("Chat with an AI trained in 4 hours for $100") gr.Markdown( "**Note:** If inference is slow, duplicate this space to host a copy " "of your own - it's small enough to run on a (free) CPU instance!", ) chatbot.render() if __name__ == "__main__": demo.launch()