File size: 17,558 Bytes
629d3bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 |
"""Nanochat model implementation and inference utilities."""
from __future__ import annotations
import json
import math
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
from torch import nn
if TYPE_CHECKING:
from collections.abc import Generator
@dataclass
class GPTConfig:
    """Hyper-parameters that fix the shape of a GPT model.

    CausalSelfAttention asserts that ``n_embd`` is divisible by ``n_head`` and
    that ``n_head`` is divisible by ``n_kv_head`` (key/value heads are shared
    across groups of query heads when ``n_kv_head < n_head``).
    """

    # Configured maximum sequence length in tokens.
    sequence_len: int = 1024
    # Number of distinct token ids (rows of the embedding / lm_head).
    vocab_size: int = 50304
    # Number of stacked transformer blocks.
    n_layer: int = 12
    # Query heads per attention layer.
    n_head: int = 6
    # Key/value heads per attention layer.
    n_kv_head: int = 6
    # Width of the token embeddings and residual stream.
    n_embd: int = 768
def norm(x: torch.Tensor) -> torch.Tensor:
"""Apply RMS normalization to input tensor."""
return F.rms_norm(x, (x.size(-1),))
# apply_rotary_emb only accepts rank-4 inputs.
_EXPECTED_NDIM = 4


def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate the two halves of the last axis of ``x`` by the given angles.

    The last dimension is split into a first and a second half which are
    mixed as ``(a*cos + b*sin, b*cos - a*sin)``.

    Args:
        x: Rank-4 tensor, documented as (batch, seq_len, n_heads, head_dim)
        cos: Cosine table, broadcast against each half of ``x``
        sin: Sine table, broadcast against each half of ``x``

    Returns:
        Rotated tensor cast back to ``x.dtype``
    """
    assert x.ndim == _EXPECTED_NDIM
    width = x.shape[3]
    half = width // 2
    first = x.narrow(3, 0, half)
    second = x.narrow(3, half, width - half)
    rotated_first = first * cos + second * sin
    rotated_second = second * cos - first * sin
    # cos/sin may have a different dtype (e.g. bfloat16); restore the input dtype.
    return torch.cat((rotated_first, rotated_second), dim=3).to(x.dtype)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate each key/value head ``n_rep`` times along dimension 1.

    Heads are repeated consecutively (h0, h0, ..., h1, h1, ...) so the output
    lines up with grouped query heads.

    Args:
        x: Rank-4 tensor of shape (batch, n_kv_heads, seq_len, head_dim)
        n_rep: Copies of each head to produce

    Returns:
        ``x`` itself when ``n_rep == 1``, otherwise a tensor of shape
        (batch, n_kv_heads * n_rep, seq_len, head_dim)
    """
    if n_rep == 1:
        return x
    batch, kv_heads, seq, dim = x.shape
    widened = x.unsqueeze(2).expand(batch, kv_heads, n_rep, seq, dim)
    return widened.flatten(1, 2)
class CausalSelfAttention(nn.Module):
    """Causal multi-head attention with rotary embeddings, QK-norm and shared KV heads.

    Query heads may outnumber key/value heads (``n_head`` must be a multiple of
    ``n_kv_head``); key/value heads are repeated to match before attention.
    """

    def __init__(self, config: GPTConfig, layer_idx: int) -> None:
        """Build the query/key/value/output projections.

        Args:
            config: Model configuration
            layer_idx: Position of this layer, passed to the KV cache
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        # Heads must split the embedding evenly and group evenly over KV heads.
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head
        assert self.n_head % self.n_kv_head == 0
        q_width = self.n_head * self.head_dim
        kv_width = self.n_kv_head * self.head_dim
        self.c_q = nn.Linear(self.n_embd, q_width, bias=False)
        self.c_k = nn.Linear(self.n_embd, kv_width, bias=False)
        self.c_v = nn.Linear(self.n_embd, kv_width, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    @staticmethod
    def _chunk_mask(
        n_queries: int,
        n_keys: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Boolean mask letting new queries see all cached keys plus earlier new keys."""
        mask = torch.zeros((n_queries, n_keys), dtype=torch.bool, device=device)
        cached = n_keys - n_queries
        if cached > 0:
            mask[:, :cached] = True
        mask[:, cached:] = torch.tril(
            torch.ones((n_queries, n_queries), dtype=torch.bool, device=device),
        )
        return mask

    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
        kv_cache: object | None,
    ) -> torch.Tensor:
        """Attend over the sequence (and any cached keys/values).

        Args:
            x: Input of shape (batch, seq_len, n_embd)
            cos_sin: Rotary (cos, sin) tables for the positions in ``x``
            kv_cache: Optional cache object exposing ``insert_kv``

        Returns:
            Tensor of shape (batch, seq_len, n_embd)
        """
        batch, seq, _ = x.size()
        queries = self.c_q(x).view(batch, seq, self.n_head, self.head_dim)
        keys = self.c_k(x).view(batch, seq, self.n_kv_head, self.head_dim)
        values = self.c_v(x).view(batch, seq, self.n_kv_head, self.head_dim)
        cos, sin = cos_sin
        # Rotary position encoding first, then RMS-norm on queries and keys.
        queries = norm(apply_rotary_emb(queries, cos, sin))
        keys = norm(apply_rotary_emb(keys, cos, sin))
        # (batch, seq, heads, dim) -> (batch, heads, seq, dim) for SDPA.
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        if kv_cache is not None:
            keys, values = kv_cache.insert_kv(self.layer_idx, keys, values)
        n_queries = queries.size(2)
        n_keys = keys.size(2)
        group = self.n_head // self.n_kv_head
        keys = repeat_kv(keys, group)
        values = repeat_kv(values, group)
        if kv_cache is None or n_queries == n_keys:
            # No cached prefix: ordinary causal attention.
            out = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)
        elif n_queries == 1:
            # Single decode step: the one query may see every key.
            out = F.scaled_dot_product_attention(queries, keys, values, is_causal=False)
        else:
            # Several new queries on top of a cached prefix.
            mask = self._chunk_mask(n_queries, n_keys, queries.device)
            out = F.scaled_dot_product_attention(queries, keys, values, attn_mask=mask)
        out = out.transpose(1, 2).contiguous().view(batch, seq, -1)
        return self.c_proj(out)
class MLP(nn.Module):
    """Two-layer feed-forward network with a squared-ReLU non-linearity."""

    def __init__(self, config: GPTConfig) -> None:
        """Create the expansion (4x) and contraction projections.

        Args:
            config: Model configuration (only ``n_embd`` is read)
        """
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width, bias=False)
        self.c_proj = nn.Linear(hidden_width, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``c_proj(relu(c_fc(x)) ** 2)``.

        Args:
            x: Input whose last dimension is ``n_embd``

        Returns:
            Tensor with the same leading shape and last dimension ``n_embd``
        """
        activated = F.relu(self.c_fc(x))
        return self.c_proj(activated.square())
class Block(nn.Module):
    """Pre-norm residual block: an attention sublayer followed by an MLP sublayer."""

    def __init__(self, config: GPTConfig, layer_idx: int) -> None:
        """Create the attention and MLP sublayers.

        Args:
            config: Model configuration
            layer_idx: Index of this block, forwarded to the attention layer
        """
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
        kv_cache: object | None,
    ) -> torch.Tensor:
        """Apply both residual sublayers, each on an RMS-normalised input.

        Args:
            x: Residual stream
            cos_sin: Rotary (cos, sin) tables for the current positions
            kv_cache: Optional KV cache, passed through to attention

        Returns:
            Updated residual stream
        """
        after_attn = x + self.attn(norm(x), cos_sin, kv_cache)
        mlp_out = self.mlp(norm(after_attn))
        return after_attn + mlp_out
class GPT(nn.Module):
    """GPT model with rotary position embeddings.

    Token embeddings flow through a stack of ``Block`` layers. The final hidden
    state is RMS-normalised, projected to vocabulary logits by a separate
    ``lm_head`` and soft-capped with ``tanh``.
    """

    def __init__(self, config: GPTConfig) -> None:
        """Initialize GPT model.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(config.vocab_size, config.n_embd),
                "h": nn.ModuleList(
                    [Block(config, layer_idx) for layer_idx in range(config.n_layer)],
                ),
            },
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # The rotary table covers 10x the configured sequence length; forward()
        # rejects positions beyond it.
        self.rotary_seq_len = config.sequence_len * 10
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        # Non-persistent buffers are not part of state_dict(), which is why
        # init_weights() recomputes them.
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.transformer.wte.to(dtype=torch.bfloat16)

    def init_weights(self) -> None:
        """Re-initialise all weights and rebuild the rotary buffers.

        ``lm_head`` and every block's output projections (``mlp.c_proj``,
        ``attn.c_proj``) are zeroed, so each residual branch initially adds
        nothing. The rotary tables are recomputed on the embedding's current
        device (e.g. after ``to_empty``).
        """
        self.apply(self._init_weights)
        torch.nn.init.zeros_(self.lm_head.weight)
        for block in self.transformer.h:
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
            torch.nn.init.zeros_(block.attn.c_proj.weight)
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights for a single module.

        Linear weights are drawn from N(0, std) with
        ``std = 1/sqrt(fan_in) * min(1, sqrt(fan_out / fan_in))``. Biases are
        zeroed and embeddings drawn from N(0, 1).

        Args:
            module: Module to initialize
        """
        if isinstance(module, nn.Linear):
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

    def _precompute_rotary_embeddings(
        self,
        seq_len: int,
        head_dim: int,
        base: int = 10000,
        device: torch.device | str | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Precompute rotary position embeddings.

        Args:
            seq_len: Number of positions to cover
            head_dim: Dimension of attention heads
            base: Base for frequency calculation
            device: Device to place tensors on (defaults to the embedding's device)

        Returns:
            Tuple of bfloat16 (cos, sin) tensors of shape
            (1, seq_len, 1, head_dim // 2)
        """
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        return cos[None, :, None, :], sin[None, :, None, :]

    def forward(
        self,
        idx: torch.Tensor,
        targets: torch.Tensor | None = None,
        kv_cache: object | None = None,
    ) -> torch.Tensor:
        """Forward pass of GPT model.

        Args:
            idx: Input token indices of shape (batch, seq_len)
            targets: Target token indices (unused in this implementation)
            kv_cache: Optional KV cache; its ``get_pos()`` gives the position of
                the first token in ``idx``

        Returns:
            Soft-capped logits of shape (batch, seq_len, vocab_size)

        Raises:
            AssertionError: If the positions covered by ``idx`` extend past the
                precomputed rotary table.
        """
        _b, t = idx.size()
        t0 = 0 if kv_cache is None else kv_cache.get_pos()
        # The new tokens occupy positions t0 .. t0+t-1. The whole span must fit in
        # the rotary table, otherwise the slices below come back short and fail
        # later with an opaque broadcasting error.
        assert t0 + t <= self.cos.size(1), (
            f"positions up to {t0 + t} exceed rotary table length {self.cos.size(1)}"
        )
        cos_sin = self.cos[:, t0 : t0 + t], self.sin[:, t0 : t0 + t]
        x = self.transformer.wte(idx)
        x = norm(x)
        for block in self.transformer.h:
            x = block(x, cos_sin, kv_cache)
        x = norm(x)
        # Squash logits smoothly into (-softcap, softcap).
        softcap = 15
        logits = self.lm_head(x)
        return softcap * torch.tanh(logits / softcap)
class NanochatModel:
    """Wrapper class for loading and running inference with the nanochat model.

    Loads a GPT checkpoint, its JSON metadata and a pickled tokenizer from one
    directory, then streams chat completions via ``generate``.
    """

    def __init__(self, model_dir: str, device: str = "cpu") -> None:
        """Initialize the NanochatModel.

        Args:
            model_dir: Directory containing ``model_*.pt``, ``meta_*.json`` and
                ``tokenizer.pkl``
            device: Device to run inference on (default: "cpu")

        Raises:
            FileNotFoundError: If a required file is missing from ``model_dir``.
            ValueError: If the tokenizer lacks a required chat special token.
        """
        self.device = torch.device(device)
        self.model_dir = model_dir
        self.model = self._load_model()
        self.enc = self._load_tokenizer()
        self._setup_special_tokens()

    def _load_model(self) -> GPT:
        """Build the GPT model from the checkpoint files in ``model_dir``.

        The ``"model_config"`` entry of the meta JSON supplies the ``GPTConfig``
        keyword arguments.

        Returns:
            The model in eval mode on ``self.device``

        Raises:
            FileNotFoundError: If no ``model_*.pt`` or no ``meta_*.json`` exists.
        """
        model_dir_path = Path(self.model_dir)
        # Sort so the choice does not depend on filesystem listing order. The
        # lexicographically last file is the newest checkpoint, assuming step
        # suffixes are zero-padded (verify against the checkpoint writer).
        model_files = sorted(model_dir_path.glob("model_*.pt"))
        if not model_files:
            msg = f"No model files found in {self.model_dir}"
            raise FileNotFoundError(msg)
        model_file = model_files[-1]
        meta_files = sorted(model_dir_path.glob("meta_*.json"))
        if not meta_files:
            msg = f"No meta files found in {self.model_dir}"
            raise FileNotFoundError(msg)
        # Prefer the meta file with the same suffix as the chosen checkpoint.
        suffix = model_file.stem.removeprefix("model_")
        paired_meta = model_dir_path / f"meta_{suffix}.json"
        meta_file = paired_meta if paired_meta in meta_files else meta_files[-1]
        with meta_file.open() as f:
            meta = json.load(f)
        model_config_kwargs = meta["model_config"]
        model_config = GPTConfig(**model_config_kwargs)
        # Build on the meta device so no memory is allocated for throwaway weights.
        with torch.device("meta"):
            model = GPT(model_config)
        model_data = torch.load(
            model_file,
            map_location=self.device,
            weights_only=True,
        )
        # Strip the prefix torch.compile adds to parameter names.
        model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
        # Upcast bfloat16 tensors to float32.
        model_data = {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }
        model.to_empty(device=self.device)
        # Still required: the rotary cos/sin buffers are non-persistent, so they
        # are not in the checkpoint and must be computed on the real device.
        model.init_weights()
        model.load_state_dict(model_data, strict=True, assign=True)
        model.eval()
        return model

    def _load_tokenizer(self) -> object:
        """Load the tokenizer from the model directory.

        Returns:
            Loaded tokenizer object

        Raises:
            FileNotFoundError: If ``tokenizer.pkl`` is missing.
        """
        tokenizer_path = Path(self.model_dir) / "tokenizer.pkl"
        if not tokenizer_path.exists():
            msg = f"Tokenizer not found at {tokenizer_path}"
            raise FileNotFoundError(msg)
        # SECURITY: unpickling can execute arbitrary code. Only point model_dir at
        # directories from a trusted source.
        with tokenizer_path.open("rb") as f:
            return pickle.load(f)

    def _setup_special_tokens(self) -> None:
        """Look up the chat special-token ids and the generation stop set.

        ``<|bos|>`` is preferred, with ``<|endoftext|>`` as a fallback. Both BOS
        and ``<|assistant_end|>`` stop generation.

        Raises:
            ValueError: If any other required special token is missing.
        """
        try:
            try:
                self.bos_token_id = self.enc.encode_single_token("<|bos|>")
            except KeyError:
                self.bos_token_id = self.enc.encode_single_token("<|endoftext|>")
            self.user_start_id = self.enc.encode_single_token("<|user_start|>")
            self.user_end_id = self.enc.encode_single_token("<|user_end|>")
            self.assistant_start_id = self.enc.encode_single_token(
                "<|assistant_start|>",
            )
            self.assistant_end_id = self.enc.encode_single_token("<|assistant_end|>")
            self.stop_tokens = {self.bos_token_id, self.assistant_end_id}
        except KeyError as e:
            msg = f"Required special token missing from tokenizer: {e}"
            raise ValueError(msg) from e

    def format_prompt(self, message: str) -> list[int]:
        """Format a single user message using chat format.

        Args:
            message: User's input message

        Returns:
            Token ids ``[bos, user_start, *message, user_end, assistant_start]``
        """
        prompt_tokens = self.enc.encode_ordinary(message)
        return [
            self.bos_token_id,
            self.user_start_id,
            *prompt_tokens,
            self.user_end_id,
            self.assistant_start_id,
        ]

    def format_conversation(self, history: list[dict[str, str]]) -> list[int]:
        """Format a multi-turn conversation using chat format.

        Args:
            history: List of message dictionaries with 'role' and 'content' keys.
                Role can be 'user' or 'assistant'; messages with any other role
                are skipped.

        Returns:
            Token ids starting with BOS and ending with ``assistant_start`` so the
            model continues as the assistant
        """
        tokens = [self.bos_token_id]
        for message in history:
            role = message.get("role")
            content = message.get("content", "")
            content_tokens = self.enc.encode_ordinary(content)
            if role == "user":
                tokens.extend([
                    self.user_start_id,
                    *content_tokens,
                    self.user_end_id,
                ])
            elif role == "assistant":
                tokens.extend([
                    self.assistant_start_id,
                    *content_tokens,
                    self.assistant_end_id,
                ])
        tokens.append(self.assistant_start_id)
        return tokens

    @staticmethod
    def _sample_next_token(
        logits: torch.Tensor,
        temperature: float,
        top_k: int,
    ) -> torch.Tensor:
        """Choose the next token id from last-position logits of shape (batch, vocab).

        A ``temperature <= 0`` selects greedily (argmax), which avoids dividing by
        zero. Otherwise the logits are temperature-scaled, optionally restricted
        to the ``top_k`` largest, and sampled.

        Returns:
            Long tensor of shape (batch, 1)
        """
        if temperature <= 0:
            return torch.argmax(logits, dim=-1, keepdim=True)
        logits = logits / temperature
        if top_k > 0:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("inf")
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def generate(
        self,
        prompt: str | None = None,
        history: list[dict[str, str]] | None = None,
        max_tokens: int = 512,
        temperature: float = 0.8,
        top_k: int = 50,
    ) -> Generator[str, None, None]:
        """Generate text from a prompt or conversation history.

        Args:
            prompt: The input text prompt (for single-turn)
            history: List of message dicts with 'role' and 'content' (for
                multi-turn); takes precedence over ``prompt``
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature; ``<= 0`` means greedy decoding
            top_k: Top-k sampling parameter (``<= 0`` disables the filter)

        Yields:
            Decoded text chunks. Concatenated, they form the completion. Tokens
            are held back while their decoded text ends in U+FFFD, i.e. while a
            multi-byte character is still incomplete. This assumes
            ``enc.decode`` replaces partial UTF-8 with U+FFFD, as tiktoken does.

        Raises:
            ValueError: If neither ``prompt`` nor ``history`` is given.
        """
        if history is not None:
            input_ids = self.format_conversation(history)
        elif prompt is not None:
            input_ids = self.format_prompt(prompt)
        else:
            msg = "Either prompt or history must be provided"
            raise ValueError(msg)
        x = torch.tensor([input_ids], dtype=torch.long, device=self.device)
        # GPT.forward rejects sequences longer than its rotary table, so only
        # the most recent tokens are fed to the model.
        max_ctx = self.model.cos.size(1)
        # Sampled ids not yet yielded because they end mid-character.
        pending: list[int] = []
        for _ in range(max_tokens):
            # Keep inference mode scoped to the model step. Grad mode is
            # thread-local, so holding it across a `yield` would leak into the
            # caller's code while this generator is suspended.
            with torch.inference_mode():
                logits = self.model(x[:, -max_ctx:])[:, -1, :]
                next_token = self._sample_next_token(logits, temperature, top_k)
                x = torch.cat([x, next_token], dim=1)
            token_id = int(next_token.item())
            if token_id in self.stop_tokens:
                break
            pending.append(token_id)
            text = self.enc.decode(pending)
            if text.endswith("\ufffd"):
                # Incomplete multi-byte character: wait for the following token(s).
                continue
            pending.clear()
            yield text
        if pending:
            # Flush whatever is left, even if it is still an incomplete character.
            yield self.enc.decode(pending)