Sam Dobson committed
Commit 5aee008 · 0 Parent(s)

First commit

Files changed (4):
  1. README.md +18 -0
  2. app.py +103 -0
  3. model.py +550 -0
  4. requirements.txt +5 -0
README.md ADDED
@@ -0,0 +1,18 @@
+ ---
+ title: Nanochat
+ emoji: 💬
+ colorFrom: yellow
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 5.49.1
+ app_file: app.py
+ pinned: false
+ hf_oauth: true
+ hf_oauth_scopes:
+ - inference-api
+ license: mit
+ ---
+
+ A lightweight chatbot powered by [nanochat](https://huggingface.co/sdobson/nanochat), a small GPT-based language model trained in 4 hours for $100. The model runs on CPU using PyTorch for fast, private inference.
+
+ Built with [Gradio](https://gradio.app) for the interface and [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index) for model distribution.
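The distribution path described above amounts to a single `snapshot_download` call. A minimal sketch (repo id taken from this Space's `MODEL_REPO` default in app.py below; the local directory name is arbitrary):

```python
# Sketch: fetch the nanochat checkpoint from the Hub, as app.py does on startup.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="sdobson/nanochat", local_dir="./model_cache")
```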
app.py ADDED
@@ -0,0 +1,103 @@
+ """Gradio interface for nanochat model."""
+
+ from __future__ import annotations
+
+ import os
+ from collections.abc import Generator
+ from pathlib import Path
+ from typing import Any
+
+ import gradio as gr
+ from huggingface_hub import snapshot_download
+
+ from model import NanochatModel
+
+ MODEL_REPO = os.environ.get("MODEL_REPO", "sdobson/nanochat")
+ MODEL_DIR = os.environ.get("MODEL_DIR", "./model_cache")
+ _model: NanochatModel | None = None
+
+
+ def download_model() -> None:
+     """Download the model from Hugging Face if needed."""
+     model_path = Path(MODEL_DIR)
+     if not model_path.exists() or not any(model_path.iterdir()):
+         snapshot_download(
+             repo_id=MODEL_REPO,
+             local_dir=MODEL_DIR,
+         )
+
+
+ def load_model() -> None:
+     """Load the nanochat model."""
+     global _model
+     if _model is None:
+         download_model()
+         _model = NanochatModel(model_dir=MODEL_DIR, device="cpu")
+
+
+ load_model()
+
+
+ def respond(
+     message: str,
+     history: list[dict[str, str]],
+     temperature: float,
+     top_k: int,
+ ) -> Generator[str, Any, None]:
+     """Generate a response using the nanochat model.
+
+     Args:
+         message: User's input message
+         history: Chat history in Gradio messages format
+         temperature: Sampling temperature
+         top_k: Top-k sampling parameter
+
+     Yields:
+         Incrementally generated response text
+
+     """
+     conversation = []
+
+     for msg in history:
+         conversation.append(msg)
+
+     conversation.append({"role": "user", "content": message})
+
+     response = ""
+     for token in _model.generate(
+         history=conversation,
+         max_tokens=512,
+         temperature=temperature,
+         top_k=top_k,
+     ):
+         response += token
+         yield response
+
+
+ chatbot = gr.ChatInterface(
+     respond,
+     type="messages",
+     additional_inputs=[
+         gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=1,
+             maximum=200,
+             value=50,
+             step=1,
+             label="Top-k sampling",
+         ),
+     ],
+ )
+
+ with gr.Blocks(title="nanochat") as demo:
+     gr.Markdown("# nanochat")
+     gr.Markdown("Chat with an AI trained in 4 hours for $100")
+     gr.Markdown(
+         "**Note:** If inference is slow, duplicate this space to host a copy "
+         "of your own - it's small enough to run on a (free) CPU instance!",
+     )
+     chatbot.render()
+
+
+ if __name__ == "__main__":
+     demo.launch()
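The streaming API that `respond` wraps can also be driven directly, without the Gradio UI. A minimal sketch, assuming the checkpoint has already been downloaded to `./model_cache` (for example via `download_model()` above):

```python
# Sketch: call NanochatModel directly instead of going through gr.ChatInterface.
from model import NanochatModel

model = NanochatModel(model_dir="./model_cache", device="cpu")

# generate() yields decoded tokens one at a time; join them for the full reply.
reply = "".join(
    model.generate(prompt="What is a language model?", max_tokens=128, temperature=0.7, top_k=50),
)
print(reply)
```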
model.py ADDED
@@ -0,0 +1,550 @@
+ """Nanochat model implementation and inference utilities."""
+
+ from __future__ import annotations
+
+ import json
+ import math
+ import pickle
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ if TYPE_CHECKING:
+     from collections.abc import Generator
+
+
+ @dataclass
+ class GPTConfig:
+     """Configuration for GPT model architecture.
+
+     Attributes:
+         sequence_len: Maximum sequence length
+         vocab_size: Size of vocabulary
+         n_layer: Number of transformer layers
+         n_head: Number of attention heads
+         n_kv_head: Number of key-value heads
+         n_embd: Embedding dimension
+
+     """
+
+     sequence_len: int = 1024
+     vocab_size: int = 50304
+     n_layer: int = 12
+     n_head: int = 6
+     n_kv_head: int = 6
+     n_embd: int = 768
+
+
+ def norm(x: torch.Tensor) -> torch.Tensor:
+     """Apply RMS normalization to input tensor."""
+     return F.rms_norm(x, (x.size(-1),))
+
+
+ _EXPECTED_NDIM = 4
+
+
+ def apply_rotary_emb(
+     x: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+ ) -> torch.Tensor:
+     """Apply rotary positional embeddings to input tensor.
+
+     Args:
+         x: Input tensor of shape (batch, seq_len, n_heads, head_dim)
+         cos: Cosine component of rotary embeddings
+         sin: Sine component of rotary embeddings
+
+     Returns:
+         Tensor with rotary embeddings applied
+
+     """
+     assert x.ndim == _EXPECTED_NDIM
+     d = x.shape[3] // 2
+     x1, x2 = x[..., :d], x[..., d:]
+     y1 = x1 * cos + x2 * sin
+     y2 = x1 * (-sin) + x2 * cos
+     return torch.cat([y1, y2], 3).to(x.dtype)
+
+
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """Repeat key/value tensors for multi-head attention.
+
+     Args:
+         x: Input tensor of shape (batch, n_kv_heads, seq_len, head_dim)
+         n_rep: Number of times to repeat
+
+     Returns:
+         Tensor with repeated key/value heads
+
+     """
+     if n_rep == 1:
+         return x
+     bs, n_kv_heads, slen, head_dim = x.shape
+     return (
+         x[:, :, None, :, :]
+         .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+         .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+     )
+
+
+ class CausalSelfAttention(nn.Module):
+     """Causal self-attention with rotary position embeddings."""
+
+     def __init__(self, config: GPTConfig, layer_idx: int) -> None:
+         """Initialize attention layer.
+
+         Args:
+             config: Model configuration
+             layer_idx: Layer index for KV cache
+
+         """
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.n_head = config.n_head
+         self.n_kv_head = config.n_kv_head
+         self.n_embd = config.n_embd
+         self.head_dim = self.n_embd // self.n_head
+         assert self.n_embd % self.n_head == 0
+         assert self.n_kv_head <= self.n_head
+         assert self.n_head % self.n_kv_head == 0
+         self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
+         self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+         self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos_sin: tuple[torch.Tensor, torch.Tensor],
+         kv_cache: object | None,
+     ) -> torch.Tensor:
+         """Forward pass of attention layer.
+
+         Args:
+             x: Input tensor
+             cos_sin: Tuple of (cos, sin) rotary embeddings
+             kv_cache: Optional KV cache for generation
+
+         Returns:
+             Output tensor after attention
+
+         """
+         b, t, _c = x.size()
+         q = self.c_q(x).view(b, t, self.n_head, self.head_dim)
+         k = self.c_k(x).view(b, t, self.n_kv_head, self.head_dim)
+         v = self.c_v(x).view(b, t, self.n_kv_head, self.head_dim)
+         cos, sin = cos_sin
+         q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
+         q, k = norm(q), norm(k)
+         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+         if kv_cache is not None:
+             k, v = kv_cache.insert_kv(self.layer_idx, k, v)
+         tq = q.size(2)
+         tk = k.size(2)
+         nrep = self.n_head // self.n_kv_head
+         k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
+         if kv_cache is None or tq == tk:
+             y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+         elif tq == 1:
+             y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
+         else:
+             attn_mask = torch.zeros((tq, tk), dtype=torch.bool, device=q.device)
+             prefix_len = tk - tq
+             if prefix_len > 0:
+                 attn_mask[:, :prefix_len] = True
+             attn_mask[:, prefix_len:] = torch.tril(
+                 torch.ones((tq, tq), dtype=torch.bool, device=q.device),
+             )
+             y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+         y = y.transpose(1, 2).contiguous().view(b, t, -1)
+         return self.c_proj(y)
+
+
+ class MLP(nn.Module):
+     """Multi-layer perceptron with squared ReLU activation."""
+
+     def __init__(self, config: GPTConfig) -> None:
+         """Initialize MLP layer.
+
+         Args:
+             config: Model configuration
+
+         """
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass of MLP.
+
+         Args:
+             x: Input tensor
+
+         Returns:
+             Output tensor after MLP transformation
+
+         """
+         x = self.c_fc(x)
+         x = F.relu(x).square()
+         return self.c_proj(x)
+
+
+ class Block(nn.Module):
+     """Transformer block with attention and MLP."""
+
+     def __init__(self, config: GPTConfig, layer_idx: int) -> None:
+         """Initialize transformer block.
+
+         Args:
+             config: Model configuration
+             layer_idx: Layer index
+
+         """
+         super().__init__()
+         self.attn = CausalSelfAttention(config, layer_idx)
+         self.mlp = MLP(config)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos_sin: tuple[torch.Tensor, torch.Tensor],
+         kv_cache: object | None,
+     ) -> torch.Tensor:
+         """Forward pass of transformer block.
+
+         Args:
+             x: Input tensor
+             cos_sin: Tuple of (cos, sin) rotary embeddings
+             kv_cache: Optional KV cache for generation
+
+         Returns:
+             Output tensor after block transformation
+
+         """
+         x = x + self.attn(norm(x), cos_sin, kv_cache)
+         return x + self.mlp(norm(x))
+
+
+ class GPT(nn.Module):
+     """GPT model with rotary position embeddings."""
+
+     def __init__(self, config: GPTConfig) -> None:
+         """Initialize GPT model.
+
+         Args:
+             config: Model configuration
+
+         """
+         super().__init__()
+         self.config = config
+         self.transformer = nn.ModuleDict(
+             {
+                 "wte": nn.Embedding(config.vocab_size, config.n_embd),
+                 "h": nn.ModuleList(
+                     [Block(config, layer_idx) for layer_idx in range(config.n_layer)],
+                 ),
+             },
+         )
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.rotary_seq_len = config.sequence_len * 10
+         head_dim = config.n_embd // config.n_head
+         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+         self.register_buffer("cos", cos, persistent=False)
+         self.register_buffer("sin", sin, persistent=False)
+         self.transformer.wte.to(dtype=torch.bfloat16)
+
+     def init_weights(self) -> None:
+         """Initialize model weights."""
+         self.apply(self._init_weights)
+         torch.nn.init.zeros_(self.lm_head.weight)
+         for block in self.transformer.h:
+             torch.nn.init.zeros_(block.mlp.c_proj.weight)
+             torch.nn.init.zeros_(block.attn.c_proj.weight)
+         head_dim = self.config.n_embd // self.config.n_head
+         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+         self.cos, self.sin = cos, sin
+
+     def _init_weights(self, module: nn.Module) -> None:
+         """Initialize weights for a single module.
+
+         Args:
+             module: Module to initialize
+
+         """
+         if isinstance(module, nn.Linear):
+             fan_out = module.weight.size(0)
+             fan_in = module.weight.size(1)
+             std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
+
+     def _precompute_rotary_embeddings(
+         self,
+         seq_len: int,
+         head_dim: int,
+         base: int = 10000,
+         device: torch.device | str | None = None,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Precompute rotary position embeddings.
+
+         Args:
+             seq_len: Maximum sequence length
+             head_dim: Dimension of attention heads
+             base: Base for frequency calculation
+             device: Device to place tensors on
+
+         Returns:
+             Tuple of (cos, sin) tensors for rotary embeddings
+
+         """
+         if device is None:
+             device = self.transformer.wte.weight.device
+         channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+         inv_freq = 1.0 / (base ** (channel_range / head_dim))
+         t = torch.arange(seq_len, dtype=torch.float32, device=device)
+         freqs = torch.outer(t, inv_freq)
+         cos, sin = freqs.cos(), freqs.sin()
+         cos, sin = cos.bfloat16(), sin.bfloat16()
+         return cos[None, :, None, :], sin[None, :, None, :]
+
+     def forward(
+         self,
+         idx: torch.Tensor,
+         targets: torch.Tensor | None = None,
+         kv_cache: object | None = None,
+     ) -> torch.Tensor:
+         """Forward pass of GPT model.
+
+         Args:
+             idx: Input token indices
+             targets: Target token indices (unused in this implementation)
+             kv_cache: Optional KV cache for generation
+
+         Returns:
+             Logits for next token prediction
+
+         """
+         _b, t = idx.size()
+         assert self.cos.size(1) >= t
+         t0 = 0 if kv_cache is None else kv_cache.get_pos()
+         cos_sin = self.cos[:, t0 : t0 + t], self.sin[:, t0 : t0 + t]
+         x = self.transformer.wte(idx)
+         x = norm(x)
+         for block in self.transformer.h:
+             x = block(x, cos_sin, kv_cache)
+         x = norm(x)
+         softcap = 15
+         logits = self.lm_head(x)
+         return softcap * torch.tanh(logits / softcap)
+
+
+ class NanochatModel:
+     """Wrapper class for loading and running inference with the nanochat model."""
+
+     def __init__(self, model_dir: str, device: str = "cpu") -> None:
+         """Initialize the NanochatModel.
+
+         Args:
+             model_dir: Directory containing model files
+             device: Device to run inference on (default: "cpu")
+
+         """
+         self.device = torch.device(device)
+         self.model_dir = model_dir
+
+         self.model = self._load_model()
+         self.enc = self._load_tokenizer()
+         self._setup_special_tokens()
+
+     def _load_model(self) -> GPT:
+         """Load the model from the model directory."""
+         model_dir_path = Path(self.model_dir)
+         model_files = list(model_dir_path.glob("model_*.pt"))
+         if not model_files:
+             msg = f"No model files found in {self.model_dir}"
+             raise FileNotFoundError(msg)
+         model_file = model_files[0]
+
+         meta_files = list(model_dir_path.glob("meta_*.json"))
+         if not meta_files:
+             msg = f"No meta files found in {self.model_dir}"
+             raise FileNotFoundError(msg)
+         meta_file = meta_files[0]
+
+         with meta_file.open() as f:
+             meta = json.load(f)
+
+         model_config_kwargs = meta["model_config"]
+
+         model_config = GPTConfig(**model_config_kwargs)
+         with torch.device("meta"):
+             model = GPT(model_config)
+
+         model_data = torch.load(
+             model_file,
+             map_location=self.device,
+             weights_only=True,
+         )
+         model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
+
+         model_data = {
+             k: v.float() if v.dtype == torch.bfloat16 else v
+             for k, v in model_data.items()
+         }
+
+         model.to_empty(device=self.device)
+         model.init_weights()
+         model.load_state_dict(model_data, strict=True, assign=True)
+         model.eval()
+
+         return model
+
+     def _load_tokenizer(self) -> object:
+         """Load the tokenizer from the model directory.
+
+         Returns:
+             Loaded tokenizer object
+
+         """
+         tokenizer_path = Path(self.model_dir) / "tokenizer.pkl"
+         if not tokenizer_path.exists():
+             msg = f"Tokenizer not found at {tokenizer_path}"
+             raise FileNotFoundError(msg)
+
+         with tokenizer_path.open("rb") as f:
+             return pickle.load(f)
+
+     def _setup_special_tokens(self) -> None:
+         """Set up special token IDs for chat formatting."""
+         try:
+             try:
+                 self.bos_token_id = self.enc.encode_single_token("<|bos|>")
+             except KeyError:
+                 self.bos_token_id = self.enc.encode_single_token("<|endoftext|>")
+
+             self.user_start_id = self.enc.encode_single_token("<|user_start|>")
+             self.user_end_id = self.enc.encode_single_token("<|user_end|>")
+             self.assistant_start_id = self.enc.encode_single_token(
+                 "<|assistant_start|>",
+             )
+             self.assistant_end_id = self.enc.encode_single_token("<|assistant_end|>")
+             self.stop_tokens = {self.bos_token_id, self.assistant_end_id}
+         except KeyError as e:
+             msg = f"Required special token missing from tokenizer: {e}"
+             raise ValueError(msg) from e
+
+     def format_prompt(self, message: str) -> list[int]:
+         """Format a user message using chat format.
+
+         Args:
+             message: User's input message
+
+         Returns:
+             List of token IDs formatted for chat
+
+         """
+         prompt_tokens = self.enc.encode_ordinary(message)
+         return [
+             self.bos_token_id,
+             self.user_start_id,
+             *prompt_tokens,
+             self.user_end_id,
+             self.assistant_start_id,
+         ]
+
+     def format_conversation(self, history: list[dict[str, str]]) -> list[int]:
+         """Format a multi-turn conversation using chat format.
+
+         Args:
+             history: List of message dictionaries with 'role' and 'content' keys
+                 role can be 'user' or 'assistant'
+
+         Returns:
+             List of token IDs formatted for multi-turn chat
+
+         """
+         tokens = [self.bos_token_id]
+
+         for message in history:
+             role = message.get("role")
+             content = message.get("content", "")
+             content_tokens = self.enc.encode_ordinary(content)
+
+             if role == "user":
+                 tokens.extend([
+                     self.user_start_id,
+                     *content_tokens,
+                     self.user_end_id,
+                 ])
+             elif role == "assistant":
+                 tokens.extend([
+                     self.assistant_start_id,
+                     *content_tokens,
+                     self.assistant_end_id,
+                 ])
+
+         tokens.append(self.assistant_start_id)
+
+         return tokens
+
+     def generate(
+         self,
+         prompt: str | None = None,
+         history: list[dict[str, str]] | None = None,
+         max_tokens: int = 512,
+         temperature: float = 0.8,
+         top_k: int = 50,
+     ) -> Generator[str, None, None]:
+         """Generate text from a prompt or conversation history.
+
+         Args:
+             prompt: The input text prompt (for single-turn)
+             history: List of message dicts with 'role' and 'content' (for multi-turn)
+             max_tokens: Maximum number of tokens to generate
+             temperature: Sampling temperature
+             top_k: Top-k sampling parameter
+
+         Yields:
+             Decoded token strings
+
+         """
+         if history is not None:
+             input_ids = self.format_conversation(history)
+         elif prompt is not None:
+             input_ids = self.format_prompt(prompt)
+         else:
+             msg = "Either prompt or history must be provided"
+             raise ValueError(msg)
+
+         x = torch.tensor([input_ids], dtype=torch.long, device=self.device)
+
+         with torch.inference_mode():
+             for _ in range(max_tokens):
+                 logits = self.model(x)
+
+                 logits = logits[:, -1, :]
+
+                 logits = logits / temperature
+
+                 if top_k > 0:
+                     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                     logits[logits < v[:, [-1]]] = -float("inf")
+
+                 probs = F.softmax(logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+
+                 if next_token.item() in self.stop_tokens:
+                     break
+
+                 token_str = self.enc.decode([next_token.item()])
+                 yield token_str
+
+                 x = torch.cat([x, next_token], dim=1)
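The tail of `generate` combines temperature scaling with top-k filtering before multinomial sampling. A standalone sketch of that filtering step on dummy logits (vocabulary size and values are illustrative only):

```python
import torch
import torch.nn.functional as F

# Dummy next-token logits for a batch of 1 over a toy vocabulary of 10 tokens.
logits = torch.randn(1, 10)

temperature, top_k = 0.8, 5
logits = logits / temperature

# Mask everything below the k-th largest logit, then sample from the remainder.
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float("inf")

probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.item())
```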
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio==5.49.1
+ torch==2.8.0
+ tiktoken==0.12.0
+ numpy==2.2.6
+ huggingface_hub==0.35.3