File size: 2,131 Bytes
85c8f79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
# Redirect all Hugging Face caches to /tmp before any HF import reads them
# (typical for Spaces/containers where the home dir is not writable).
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")

from transformers import AutoModel
from huggingface_hub import hf_hub_download
import torch
import gradio as gr
import pickle

MODEL_ID = "loocorez/nanochat-sft-d20-test"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model via Auto* with trust_remote_code
# NOTE(review): trust_remote_code executes Python shipped with the repo —
# acceptable only because MODEL_ID is a fixed, known repository.
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
model = model.to(device)
model.eval()

# Load tokenizer.pkl directly (avoid AutoTokenizer mapping issues)
# NOTE(review): pickle.load of a downloaded file runs arbitrary code on
# unpickling; safe only because the repo is trusted (same trust boundary
# as trust_remote_code above).
tok_path = hf_hub_download(MODEL_ID, filename="tokenizer.pkl")

class PklTokenizer:
    """Thin wrapper around a pickled tokenizer object.

    Exposes the minimal interface the demo needs: encode text to ids
    (optionally prefixed with one extra id), decode ids back to text,
    and look up the BOS token id.
    """

    def __init__(self, pkl_file):
        # Unpickle the underlying encoder and cache the BOS id once.
        with open(pkl_file, "rb") as handle:
            self.enc = pickle.load(handle)
        self._bos = self.enc.encode_single_token("<|bos|>")

    def get_bos_token_id(self):
        """Return the id of the "<|bos|>" token."""
        return self._bos

    def encode(self, text, prepend=None):
        """Encode *text* to a list of ids, optionally prefixing one id."""
        token_ids = self.enc.encode_ordinary(text)
        return token_ids if prepend is None else [prepend] + token_ids

    def decode(self, ids):
        """Map a sequence of token ids back to a string."""
        return self.enc.decode(ids)

tokenizer = PklTokenizer(tok_path)

def complete(prompt, max_new_tokens=64):
    """Greedy-decode a continuation of *prompt* with the global model.

    Args:
        prompt: Input text; encoded with a BOS token prepended.
        max_new_tokens: Number of tokens to generate. May arrive as a
            float (Gradio Slider values are numbers, e.g. 64.0), so it
            is coerced to int below.

    Returns:
        The decoded string for the full sequence (prompt + generated
        tokens), as produced by tokenizer.decode.
    """
    # BUGFIX: Gradio's Slider passes a float; range() requires an int.
    max_new_tokens = int(max_new_tokens)
    input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
    ids = torch.tensor([input_ids], dtype=torch.long, device=device)
    with torch.inference_mode():
        for _ in range(max_new_tokens):
            # No KV cache: re-run the full sequence each step (O(n^2)
            # overall, acceptable for a small demo).
            outputs = model(input_ids=ids)
            # Remote-code models may return a dict or a ModelOutput.
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            # Greedy decoding: pick the argmax of the last position.
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            ids = torch.cat([ids, next_token], dim=1)
    return tokenizer.decode(ids[0].tolist())

# Minimal Gradio UI: one prompt box, a token-count slider, a button that
# runs greedy generation, and an output box for the decoded text.
with gr.Blocks() as demo:
    gr.Markdown("# NanoChat Transformers Demo (SFT d20)")
    inp = gr.Textbox(value="The capital of Belgium is ")
    # Slider delivers a numeric value (may be a float) as the second
    # argument to complete().
    max_toks = gr.Slider(1, 256, value=64, step=1, label="Max new tokens")
    out = gr.Textbox()
    btn = gr.Button("Generate")
    # Wire the button: complete(inp, max_toks) -> out.
    btn.click(complete, [inp, max_toks], [out])

# Blocking call: starts the web server (module-level side effect).
demo.launch()