"""Gradio demo for nanochat checkpoints: single-model generation and SFT/MID/BASE comparison."""

import os
import json

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer as HFTokenizer

from gpt_infer import GPT, GPTConfig

DEFAULT_MODEL = os.environ.get('NANOCHAT_DEFAULT_MODEL', 'loocorez/nanochat-sft-d20-step650')
ALL_MODELS = [
    'loocorez/nanochat-sft-d20-step650',
    'loocorez/nanochat-mid-d20-step765',
    'loocorez/nanochat-base-d20-step21400',
]


@torch.inference_mode()
def load_model(repo_id: str):
    # Download config, weights, and tokenizer from the Hub and build the model.
    cfg_path = hf_hub_download(repo_id, 'config.json')
    with open(cfg_path, 'r') as f:
        cfg = json.load(f)
    model_config = GPTConfig(
        sequence_len=cfg.get('n_ctx', 2048),
        vocab_size=cfg['vocab_size'],
        n_layer=cfg['n_layer'],
        n_head=cfg['n_head'],
        n_kv_head=cfg.get('n_kv_head', cfg['n_head']),
        n_embd=cfg['n_embd'],
    )
    model = GPT(model_config)
    model.eval()
    # Prefer safetensors weights; fall back to a pytorch_model.bin checkpoint.
    try:
        from safetensors.torch import load_file
        weights_path = hf_hub_download(repo_id, 'model.safetensors')
        sd = load_file(weights_path)
    except Exception:
        weights_path = hf_hub_download(repo_id, 'pytorch_model.bin')
        sd = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(sd, strict=True, assign=True)
    tok_path = hf_hub_download(repo_id, 'tokenizer.json')
    tok = HFTokenizer.from_file(tok_path)
    return model, tok


model_cache = {}


def get_model(repo_id: str):
    # Load each repo at most once per process.
    if repo_id not in model_cache:
        model_cache[repo_id] = load_model(repo_id)
    return model_cache[repo_id]


@torch.inference_mode()
def generate(repo_id: str, system_prompt: str, prompt: str, max_tokens: int, temperature: float, top_k: int | None):
    model, tok = get_model(repo_id)
    bos_id = tok.token_to_id('<|bos|>')
    # Combine system + user prompt.
    text = prompt if not system_prompt else f"{system_prompt.strip()} {prompt}"
    ids = tok.encode(text).ids
    # Prepend the BOS token if the tokenizer defines one.
    if bos_id is not None:
        ids = [bos_id] + ids
    out_tokens = []
    for token in model.generate(ids, max_tokens=max_tokens, temperature=temperature, top_k=top_k or None):
        out_tokens.append(token)
    out_text = tok.decode(out_tokens, skip_special_tokens=False)
    return out_text


@torch.inference_mode()
def compare_three(system_prompt: str, prompt: str, max_tokens: int, temperature: float, top_k: int | None):
    # Run the same prompt through all three checkpoints, in ALL_MODELS order.
    outputs = []
    for repo_id in ALL_MODELS:
        outputs.append(generate(repo_id, system_prompt, prompt, max_tokens, temperature, top_k))
    return tuple(outputs)


with gr.Blocks() as demo:
    gr.Markdown('# nanochat (ZeroGPU)')
    gr.Markdown('Run a single model or compare SFT/MID/BASE side by side.')
    with gr.Tabs():
        with gr.Tab('Single'):
            repo = gr.Dropdown(choices=ALL_MODELS, value=DEFAULT_MODEL, label='Model Repo')
            system = gr.Textbox(label='System prompt (optional)', value='You are a helpful assistant.', lines=2)
            prompt = gr.Textbox(label='User prompt', lines=6)
            with gr.Row():
                max_tokens = gr.Slider(1, 256, value=128, step=1, label='Max tokens')
                temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label='Temperature')
                top_k = gr.Slider(0, 100, value=40, step=1, label='Top-k (0=disabled)')
            btn = gr.Button('Generate')
            output = gr.Textbox(label='Output', lines=10)
            btn.click(
                fn=lambda r, s, p, m, t, k: generate(r, s, p, int(m), float(t), int(k) if int(k) > 0 else None),
                inputs=[repo, system, prompt, max_tokens, temperature, top_k],
                outputs=output,
            )
        with gr.Tab('Compare 3'):
            system_c = gr.Textbox(label='System prompt (optional)', value='You are a helpful assistant.', lines=2)
            prompt_c = gr.Textbox(label='User prompt', lines=6)
            with gr.Row():
                max_tokens_c = gr.Slider(1, 256, value=128, step=1, label='Max tokens')
                temperature_c = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label='Temperature')
                top_k_c = gr.Slider(0, 100, value=40, step=1, label='Top-k (0=disabled)')
            btn_c = gr.Button('Run on all three')
            with gr.Row():
                out_sft = gr.Textbox(label='SFT', lines=10)
                out_mid = gr.Textbox(label='MID', lines=10)
                out_base = gr.Textbox(label='BASE', lines=10)
            btn_c.click(
                fn=lambda s, p, m, t, k: compare_three(s, p, int(m), float(t), int(k) if int(k) > 0 else None),
                inputs=[system_c, prompt_c, max_tokens_c, temperature_c, top_k_c],
                outputs=[out_sft, out_mid, out_base],
            )

if __name__ == '__main__':
    demo.launch()
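
# ---------------------------------------------------------------------------
# Usage sketch (not part of the app): calling generate() directly, e.g. from a
# REPL or test script. This assumes the file is saved as app.py and that the
# prompt text and sampling values below are purely illustrative.
#
#   from app import generate, DEFAULT_MODEL
#
#   text = generate(
#       DEFAULT_MODEL,
#       system_prompt='You are a helpful assistant.',
#       prompt='Explain what nanochat is in one sentence.',
#       max_tokens=64,
#       temperature=0.8,
#       top_k=40,
#   )
#   print(text)
# ---------------------------------------------------------------------------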