import os
import json

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer as HFTokenizer

from gpt_infer import GPT, GPTConfig  # local module bundled with the Space

DEFAULT_MODEL = os.environ.get('NANOCHAT_DEFAULT_MODEL', 'loocorez/nanochat-sft-d20-step650')
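# The default checkpoint can be overridden per deployment, e.g.:
#   NANOCHAT_DEFAULT_MODEL=loocorez/nanochat-mid-d20-step765 python app.py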
def load_model(repo_id: str):
    # Pull config.json from the Hub and map its fields onto the GPTConfig
    # expected by gpt_infer.
    cfg_path = hf_hub_download(repo_id, 'config.json')
    with open(cfg_path, 'r') as f:
        cfg = json.load(f)
    model_config = GPTConfig(
        sequence_len=cfg.get('n_ctx', 2048),
        vocab_size=cfg['vocab_size'],
        n_layer=cfg['n_layer'],
        n_head=cfg['n_head'],
        n_kv_head=cfg.get('n_kv_head', cfg['n_head']),
        n_embd=cfg['n_embd'],
    )
    model = GPT(model_config)
    model.eval()
    # Prefer safetensors weights; fall back to a pytorch_model.bin checkpoint
    # if the repo does not ship model.safetensors.
    try:
        from safetensors.torch import load_file
        weights_path = hf_hub_download(repo_id, 'model.safetensors')
        sd = load_file(weights_path)
    except Exception:
        weights_path = hf_hub_download(repo_id, 'pytorch_model.bin')
        sd = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(sd, strict=True, assign=True)
    tok_path = hf_hub_download(repo_id, 'tokenizer.json')
    tok = HFTokenizer.from_file(tok_path)
    return model, tok
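# Note: everything above loads on CPU. If this Space were to actually use
# ZeroGPU hardware, the usual pattern (a sketch, not wired in here) is to
# decorate the GPU-bound entry point with `@spaces.GPU` from the `spaces`
# package and move the model to CUDA for the duration of the call:
#
#   import spaces
#
#   @spaces.GPU
#   def generate_on_gpu(...):
#       model.to('cuda')
#       ...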
model_cache = {}

def get_model(repo_id: str):
    if repo_id not in model_cache:
        model_cache[repo_id] = load_model(repo_id)
    return model_cache[repo_id]
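# The cache is keyed by repo id and never evicted: selecting all three
# dropdown models in one session keeps all three resident in memory.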
def generate(repo_id: str, prompt: str, max_tokens: int, temperature: float, top_k: int | None):
    model, tok = get_model(repo_id)
    # Prepend the BOS token when the tokenizer defines one.
    bos_id = tok.token_to_id('<|bos|>')
    ids = tok.encode(prompt).ids
    if bos_id is not None:
        ids = [bos_id] + ids
    out_tokens = []
    # model.generate streams token ids one at a time; top_k=0 is treated as disabled.
    for token in model.generate(ids, max_tokens=max_tokens, temperature=temperature, top_k=top_k or None):
        out_tokens.append(token)
    text = tok.decode(out_tokens, skip_special_tokens=False)
    return text
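# Example direct call (bypassing the UI), using one of the dropdown repos:
#   print(generate('loocorez/nanochat-sft-d20-step650', 'Hello', 64, 0.8, 40))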
with gr.Blocks() as demo:
    gr.Markdown('# nanochat (ZeroGPU)')
    gr.Markdown('Select a model and generate text.')
    repo = gr.Dropdown(
        choices=[
            'loocorez/nanochat-sft-d20-step650',
            'loocorez/nanochat-mid-d20-step765',
            'loocorez/nanochat-base-d20-step21400',
        ],
        value=DEFAULT_MODEL,
        label='Model Repo',
    )
    prompt = gr.Textbox(label='Prompt', lines=6)
    with gr.Row():
        max_tokens = gr.Slider(1, 256, value=128, step=1, label='Max tokens')
        temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label='Temperature')
        top_k = gr.Slider(0, 100, value=40, step=1, label='Top-k (0=disabled)')
    btn = gr.Button('Generate')
    output = gr.Textbox(label='Output', lines=10)
    # Cast slider values and map top_k == 0 to None (sampling over the full vocab).
    btn.click(
        fn=lambda r, p, m, t, k: generate(r, p, int(m), float(t), int(k) if int(k) > 0 else None),
        inputs=[repo, prompt, max_tokens, temperature, top_k],
        outputs=output,
    )
if __name__ == '__main__':
    demo.launch()
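# Run locally with `python app.py`; Gradio serves on http://localhost:7860 by
# default. On a Space, the Gradio SDK runs this file automatically.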