"""Nanochat model implementation and inference utilities.""" from __future__ import annotations import json import math import pickle from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING import torch import torch.nn.functional as F from torch import nn if TYPE_CHECKING: from collections.abc import Generator @dataclass class GPTConfig: """Configuration for GPT model architecture. Attributes: sequence_len: Maximum sequence length vocab_size: Size of vocabulary n_layer: Number of transformer layers n_head: Number of attention heads n_kv_head: Number of key-value heads n_embd: Embedding dimension """ sequence_len: int = 1024 vocab_size: int = 50304 n_layer: int = 12 n_head: int = 6 n_kv_head: int = 6 n_embd: int = 768 def norm(x: torch.Tensor) -> torch.Tensor: """Apply RMS normalization to input tensor.""" return F.rms_norm(x, (x.size(-1),)) _EXPECTED_NDIM = 4 def apply_rotary_emb( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> torch.Tensor: """Apply rotary positional embeddings to input tensor. Args: x: Input tensor of shape (batch, seq_len, n_heads, head_dim) cos: Cosine component of rotary embeddings sin: Sine component of rotary embeddings Returns: Tensor with rotary embeddings applied """ assert x.ndim == _EXPECTED_NDIM d = x.shape[3] // 2 x1, x2 = x[..., :d], x[..., d:] y1 = x1 * cos + x2 * sin y2 = x1 * (-sin) + x2 * cos return torch.cat([y1, y2], 3).to(x.dtype) def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """Repeat key/value tensors for multi-head attention. Args: x: Input tensor of shape (batch, n_kv_heads, seq_len, head_dim) n_rep: Number of times to repeat Returns: Tensor with repeated key/value heads """ if n_rep == 1: return x bs, n_kv_heads, slen, head_dim = x.shape return ( x[:, :, None, :, :] .expand(bs, n_kv_heads, n_rep, slen, head_dim) .reshape(bs, n_kv_heads * n_rep, slen, head_dim) ) class CausalSelfAttention(nn.Module): """Causal self-attention with rotary position embeddings.""" def __init__(self, config: GPTConfig, layer_idx: int) -> None: """Initialize attention layer. Args: config: Model configuration layer_idx: Layer index for KV cache """ super().__init__() self.layer_idx = layer_idx self.n_head = config.n_head self.n_kv_head = config.n_kv_head self.n_embd = config.n_embd self.head_dim = self.n_embd // self.n_head assert self.n_embd % self.n_head == 0 assert self.n_kv_head <= self.n_head assert self.n_head % self.n_kv_head == 0 self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False) self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) def forward( self, x: torch.Tensor, cos_sin: tuple[torch.Tensor, torch.Tensor], kv_cache: object | None, ) -> torch.Tensor: """Forward pass of attention layer. 


class CausalSelfAttention(nn.Module):
    """Causal self-attention with rotary position embeddings."""

    def __init__(self, config: GPTConfig, layer_idx: int) -> None:
        """Initialize attention layer.

        Args:
            config: Model configuration
            layer_idx: Layer index for KV cache
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head
        assert self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
        kv_cache: object | None,
    ) -> torch.Tensor:
        """Forward pass of attention layer.

        Args:
            x: Input tensor
            cos_sin: Tuple of (cos, sin) rotary embeddings
            kv_cache: Optional KV cache for generation

        Returns:
            Output tensor after attention
        """
        b, t, _c = x.size()
        q = self.c_q(x).view(b, t, self.n_head, self.head_dim)
        k = self.c_k(x).view(b, t, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(b, t, self.n_kv_head, self.head_dim)
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
        tq = q.size(2)
        tk = k.size(2)
        nrep = self.n_head // self.n_kv_head
        k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
        if kv_cache is None or tq == tk:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        elif tq == 1:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        else:
            attn_mask = torch.zeros((tq, tk), dtype=torch.bool, device=q.device)
            prefix_len = tk - tq
            if prefix_len > 0:
                attn_mask[:, :prefix_len] = True
            attn_mask[:, prefix_len:] = torch.tril(
                torch.ones((tq, tq), dtype=torch.bool, device=q.device),
            )
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        y = y.transpose(1, 2).contiguous().view(b, t, -1)
        return self.c_proj(y)


class MLP(nn.Module):
    """Multi-layer perceptron with squared ReLU activation."""

    def __init__(self, config: GPTConfig) -> None:
        """Initialize MLP layer.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of MLP.

        Args:
            x: Input tensor

        Returns:
            Output tensor after MLP transformation
        """
        x = self.c_fc(x)
        x = F.relu(x).square()
        return self.c_proj(x)


class Block(nn.Module):
    """Transformer block with attention and MLP."""

    def __init__(self, config: GPTConfig, layer_idx: int) -> None:
        """Initialize transformer block.

        Args:
            config: Model configuration
            layer_idx: Layer index
        """
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
        kv_cache: object | None,
    ) -> torch.Tensor:
        """Forward pass of transformer block.

        Args:
            x: Input tensor
            cos_sin: Tuple of (cos, sin) rotary embeddings
            kv_cache: Optional KV cache for generation

        Returns:
            Output tensor after block transformation
        """
        x = x + self.attn(norm(x), cos_sin, kv_cache)
        return x + self.mlp(norm(x))
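

# The third branch of the attention dispatch in CausalSelfAttention.forward builds an
# explicit mask for the case where several new tokens attend to a cached prefix: cached
# positions are fully visible, while the new chunk is causal with respect to itself.
# The sketch below reproduces just that mask construction with small, arbitrary sizes
# so the pattern is easy to inspect; it is illustrative only and not used by the model.
def _prefix_causal_mask_example(tq: int = 3, tk: int = 7) -> torch.Tensor:
    """Build the prefix + causal boolean mask used for chunked decoding."""
    attn_mask = torch.zeros((tq, tk), dtype=torch.bool)
    prefix_len = tk - tq
    if prefix_len > 0:
        # Every new token may attend to the whole cached prefix.
        attn_mask[:, :prefix_len] = True
    # Within the new chunk, token i attends only to tokens 0..i.
    attn_mask[:, prefix_len:] = torch.tril(torch.ones((tq, tq), dtype=torch.bool))
    return attn_mask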


class GPT(nn.Module):
    """GPT model with rotary position embeddings."""

    def __init__(self, config: GPTConfig) -> None:
        """Initialize GPT model.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(config.vocab_size, config.n_embd),
                "h": nn.ModuleList(
                    [Block(config, layer_idx) for layer_idx in range(config.n_layer)],
                ),
            },
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.rotary_seq_len = config.sequence_len * 10
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.transformer.wte.to(dtype=torch.bfloat16)

    def init_weights(self) -> None:
        """Initialize model weights."""
        self.apply(self._init_weights)
        torch.nn.init.zeros_(self.lm_head.weight)
        for block in self.transformer.h:
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
            torch.nn.init.zeros_(block.attn.c_proj.weight)
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights for a single module.

        Args:
            module: Module to initialize
        """
        if isinstance(module, nn.Linear):
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

    def _precompute_rotary_embeddings(
        self,
        seq_len: int,
        head_dim: int,
        base: int = 10000,
        device: torch.device | str | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Precompute rotary position embeddings.

        Args:
            seq_len: Maximum sequence length
            head_dim: Dimension of attention heads
            base: Base for frequency calculation
            device: Device to place tensors on

        Returns:
            Tuple of (cos, sin) tensors for rotary embeddings
        """
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        return cos[None, :, None, :], sin[None, :, None, :]

    def forward(
        self,
        idx: torch.Tensor,
        targets: torch.Tensor | None = None,
        kv_cache: object | None = None,
    ) -> torch.Tensor:
        """Forward pass of GPT model.

        Args:
            idx: Input token indices
            targets: Target token indices (unused in this implementation)
            kv_cache: Optional KV cache for generation

        Returns:
            Logits for next token prediction
        """
        _b, t = idx.size()
        assert self.cos.size(1) >= t
        t0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, t0 : t0 + t], self.sin[:, t0 : t0 + t]
        x = self.transformer.wte(idx)
        x = norm(x)
        for block in self.transformer.h:
            x = block(x, cos_sin, kv_cache)
        x = norm(x)
        softcap = 15
        logits = self.lm_head(x)
        return softcap * torch.tanh(logits / softcap)
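

# The final projection in GPT.forward applies "soft capping": logits are squashed
# through tanh so they can never exceed +/- softcap (15 here), which keeps the loss
# and sampling numerically tame. The helper below just demonstrates that bound on a
# toy tensor; it is illustrative only and never called by the model.
def _softcap_example() -> None:
    """Show that soft-capped logits stay strictly within (-softcap, softcap)."""
    softcap = 15
    raw_logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
    capped = softcap * torch.tanh(raw_logits / softcap)
    assert capped.abs().max() < softcap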


class NanochatModel:
    """Wrapper class for loading and running inference with the nanochat model."""

    def __init__(self, model_dir: str, device: str = "cpu") -> None:
        """Initialize the NanochatModel.

        Args:
            model_dir: Directory containing model files
            device: Device to run inference on (default: "cpu")
        """
        self.device = torch.device(device)
        self.model_dir = model_dir
        self.model = self._load_model()
        self.enc = self._load_tokenizer()
        self._setup_special_tokens()

    def _load_model(self) -> GPT:
        """Load the model from the model directory."""
        model_dir_path = Path(self.model_dir)
        model_files = list(model_dir_path.glob("model_*.pt"))
        if not model_files:
            msg = f"No model files found in {self.model_dir}"
            raise FileNotFoundError(msg)
        model_file = model_files[0]
        meta_files = list(model_dir_path.glob("meta_*.json"))
        if not meta_files:
            msg = f"No meta files found in {self.model_dir}"
            raise FileNotFoundError(msg)
        meta_file = meta_files[0]
        with meta_file.open() as f:
            meta = json.load(f)
        model_config_kwargs = meta["model_config"]
        model_config = GPTConfig(**model_config_kwargs)
        with torch.device("meta"):
            model = GPT(model_config)
        model_data = torch.load(
            model_file,
            map_location=self.device,
            weights_only=True,
        )
        model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
        model_data = {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }
        model.to_empty(device=self.device)
        model.init_weights()
        model.load_state_dict(model_data, strict=True, assign=True)
        model.eval()
        return model

    def _load_tokenizer(self) -> object:
        """Load the tokenizer from the model directory.

        Returns:
            Loaded tokenizer object
        """
        tokenizer_path = Path(self.model_dir) / "tokenizer.pkl"
        if not tokenizer_path.exists():
            msg = f"Tokenizer not found at {tokenizer_path}"
            raise FileNotFoundError(msg)
        with tokenizer_path.open("rb") as f:
            return pickle.load(f)

    def _setup_special_tokens(self) -> None:
        """Set up special token IDs for chat formatting."""
        try:
            try:
                self.bos_token_id = self.enc.encode_single_token("<|bos|>")
            except KeyError:
                self.bos_token_id = self.enc.encode_single_token("<|endoftext|>")
            self.user_start_id = self.enc.encode_single_token("<|user_start|>")
            self.user_end_id = self.enc.encode_single_token("<|user_end|>")
            self.assistant_start_id = self.enc.encode_single_token(
                "<|assistant_start|>",
            )
            self.assistant_end_id = self.enc.encode_single_token("<|assistant_end|>")
            self.stop_tokens = {self.bos_token_id, self.assistant_end_id}
        except KeyError as e:
            msg = f"Required special token missing from tokenizer: {e}"
            raise ValueError(msg) from e

    def format_prompt(self, message: str) -> list[int]:
        """Format a user message using chat format.

        Args:
            message: User's input message

        Returns:
            List of token IDs formatted for chat
        """
        prompt_tokens = self.enc.encode_ordinary(message)
        return [
            self.bos_token_id,
            self.user_start_id,
            *prompt_tokens,
            self.user_end_id,
            self.assistant_start_id,
        ]
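
    # For reference, format_prompt("Hi") produces token IDs laid out as follows
    # (special-token names shown in place of their integer IDs, which depend on
    # the pickled tokenizer):
    #
    #     <|bos|> <|user_start|> Hi <|user_end|> <|assistant_start|>
    #
    # format_conversation below repeats the user/assistant segments for each turn
    # in the history before reopening an assistant segment for the next reply.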

    def format_conversation(self, history: list[dict[str, str]]) -> list[int]:
        """Format a multi-turn conversation using chat format.

        Args:
            history: List of message dictionaries with 'role' and 'content' keys;
                role can be 'user' or 'assistant'

        Returns:
            List of token IDs formatted for multi-turn chat
        """
        tokens = [self.bos_token_id]
        for message in history:
            role = message.get("role")
            content = message.get("content", "")
            content_tokens = self.enc.encode_ordinary(content)
            if role == "user":
                tokens.extend([
                    self.user_start_id,
                    *content_tokens,
                    self.user_end_id,
                ])
            elif role == "assistant":
                tokens.extend([
                    self.assistant_start_id,
                    *content_tokens,
                    self.assistant_end_id,
                ])
        tokens.append(self.assistant_start_id)
        return tokens

    def generate(
        self,
        prompt: str | None = None,
        history: list[dict[str, str]] | None = None,
        max_tokens: int = 512,
        temperature: float = 0.8,
        top_k: int = 50,
    ) -> Generator[str, None, None]:
        """Generate text from a prompt or conversation history.

        Args:
            prompt: The input text prompt (for single-turn)
            history: List of message dicts with 'role' and 'content' (for multi-turn)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_k: Top-k sampling parameter

        Yields:
            Decoded token strings
        """
        if history is not None:
            input_ids = self.format_conversation(history)
        elif prompt is not None:
            input_ids = self.format_prompt(prompt)
        else:
            msg = "Either prompt or history must be provided"
            raise ValueError(msg)
        x = torch.tensor([input_ids], dtype=torch.long, device=self.device)
        with torch.inference_mode():
            for _ in range(max_tokens):
                logits = self.model(x)
                logits = logits[:, -1, :]
                logits = logits / temperature
                if top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                if next_token.item() in self.stop_tokens:
                    break
                token_str = self.enc.decode([next_token.item()])
                yield token_str
                x = torch.cat([x, next_token], dim=1)
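

# Minimal usage sketch. It assumes a checkpoint directory containing model_*.pt,
# meta_*.json, and tokenizer.pkl; the path "checkpoints/base" below is a placeholder,
# not something this module creates, so the snippet is kept as a comment:
#
#     model = NanochatModel("checkpoints/base", device="cpu")
#     for piece in model.generate(prompt="Hello there", max_tokens=64):
#         print(piece, end="", flush=True)
#
# or, for a multi-turn exchange:
#
#     history = [
#         {"role": "user", "content": "Name a prime number."},
#         {"role": "assistant", "content": "Seven."},
#         {"role": "user", "content": "And another?"},
#     ]
#     text = "".join(model.generate(history=history, temperature=0.7))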