# Retrieved from the Hugging Face Hub file view: uploaded by loocorez via
# huggingface_hub ("Upload folder using huggingface_hub"), commit c29ec17, 2.89 kB.
import os
import json
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer as HFTokenizer
from gpt_infer import GPT, GPTConfig
# Hub repo id preselected in the model dropdown; set the
# NANOCHAT_DEFAULT_MODEL environment variable to override it.
DEFAULT_MODEL = os.getenv('NANOCHAT_DEFAULT_MODEL', 'loocorez/nanochat-sft-d20-step650')
def load_model(repo_id: str):
    """Download a nanochat checkpoint from the Hugging Face Hub and build it.

    Fetches ``config.json``, the weights (``model.safetensors`` preferred,
    ``pytorch_model.bin`` as fallback) and ``tokenizer.json`` from *repo_id*.

    Args:
        repo_id: Hub repository id, e.g. ``'loocorez/nanochat-sft-d20-step650'``.

    Returns:
        A ``(model, tokenizer)`` tuple: a ``GPT`` switched to eval mode with the
        checkpoint weights loaded (strictly -- every key must match), and a
        ``tokenizers.Tokenizer`` built from ``tokenizer.json``.
    """
    cfg_path = hf_hub_download(repo_id, 'config.json')
    with open(cfg_path, 'r', encoding='utf-8') as f:
        cfg = json.load(f)
    model_config = GPTConfig(
        sequence_len=cfg.get('n_ctx', 2048),
        vocab_size=cfg['vocab_size'],
        n_layer=cfg['n_layer'],
        n_head=cfg['n_head'],
        # Defaults to n_head when the config omits it (presumably plain
        # multi-head attention with no KV-head sharing -- per gpt_infer).
        n_kv_head=cfg.get('n_kv_head', cfg['n_head']),
        n_embd=cfg['n_embd'],
    )
    model = GPT(model_config)
    model.eval()
    try:
        from safetensors.torch import load_file
        weights_path = hf_hub_download(repo_id, 'model.safetensors')
        sd = load_file(weights_path)
    except Exception:
        # Best-effort fallback for any failure above (safetensors not installed,
        # no safetensors file in the repo, ...): use the PyTorch pickle instead.
        weights_path = hf_hub_download(repo_id, 'pytorch_model.bin')
        # The .bin is fetched from the network; weights_only=True restricts
        # unpickling to tensors and plain containers so a tampered file cannot
        # run arbitrary code during load.
        sd = torch.load(weights_path, map_location='cpu', weights_only=True)
    # assign=True makes the module adopt the loaded tensors (keeping their
    # dtype/device) instead of copying into the freshly initialised parameters.
    model.load_state_dict(sd, strict=True, assign=True)
    tok_path = hf_hub_download(repo_id, 'tokenizer.json')
    tok = HFTokenizer.from_file(tok_path)
    return model, tok
# Process-wide cache of (model, tokenizer) pairs keyed by Hub repo id, so a
# checkpoint that was already loaded is reused instead of rebuilt.
model_cache = {}


def get_model(repo_id: str):
    """Return the ``(model, tokenizer)`` pair for *repo_id*, loading it on first use."""
    try:
        return model_cache[repo_id]
    except KeyError:
        pass
    loaded = load_model(repo_id)
    model_cache[repo_id] = loaded
    return loaded
@torch.inference_mode()
def generate(repo_id: str, prompt: str, max_tokens: int, temperature: float, top_k: int | None):
    """Sample a continuation of *prompt* from the checkpoint in *repo_id*.

    Runs under ``torch.inference_mode()`` so no autograd state is recorded.

    Args:
        repo_id: Hub repo id; the model/tokenizer pair comes from ``get_model``.
        prompt: Raw prompt text, encoded with that repo's tokenizer.
        max_tokens: Forwarded to ``model.generate`` as ``max_tokens``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff; falsy values (``0`` / ``None``) are forwarded as
            ``None`` (presumably "no top-k filtering" -- depends on gpt_infer).

    Returns:
        The decoded text of every token yielded by ``model.generate``, with
        special tokens kept in the output.
    """
    model, tok = get_model(repo_id)
    bos_id = tok.token_to_id('<|bos|>')
    prompt_ids = tok.encode(prompt).ids
    # Prefix the BOS token only when this tokenizer actually defines one.
    context = prompt_ids if bos_id is None else [bos_id, *prompt_ids]
    sampled = list(model.generate(context, max_tokens=max_tokens, temperature=temperature, top_k=top_k or None))
    return tok.decode(sampled, skip_special_tokens=False)
# Checkpoints offered in the model dropdown. DEFAULT_MODEL can be overridden via
# the NANOCHAT_DEFAULT_MODEL environment variable; gr.Dropdown (without
# allow_custom_value) warns about and rejects a value missing from its choices,
# so an override that is not listed here is appended to keep it selectable.
_MODEL_CHOICES = [
    'loocorez/nanochat-sft-d20-step650',
    'loocorez/nanochat-mid-d20-step765',
    'loocorez/nanochat-base-d20-step21400',
]
if DEFAULT_MODEL not in _MODEL_CHOICES:
    _MODEL_CHOICES.append(DEFAULT_MODEL)

with gr.Blocks() as demo:
    gr.Markdown('# nanochat (ZeroGPU)')
    gr.Markdown('Select a model and generate text.')
    repo = gr.Dropdown(choices=_MODEL_CHOICES, value=DEFAULT_MODEL, label='Model Repo')
    prompt = gr.Textbox(label='Prompt', lines=6)
    with gr.Row():
        max_tokens = gr.Slider(1, 256, value=128, step=1, label='Max tokens')
        temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label='Temperature')
        top_k = gr.Slider(0, 100, value=40, step=1, label='Top-k (0=disabled)')
    btn = gr.Button('Generate')
    output = gr.Textbox(label='Output', lines=10)
    # Slider values may arrive as floats: coerce them to the types generate()
    # expects, and map a top-k of 0 to None (disabled).
    btn.click(
        fn=lambda r, p, m, t, k: generate(r, p, int(m), float(t), int(k) if int(k) > 0 else None),
        inputs=[repo, prompt, max_tokens, temperature, top_k],
        outputs=output,
    )
def _main() -> None:
    """Start the Gradio server for the ``demo`` app."""
    demo.launch()


if __name__ == '__main__':
    _main()