# app.py (Corrected Version)
import os
import math
import pickle
import shutil
import subprocess
import sys
import textwrap
import time
from dataclasses import dataclass
from typing import Optional

import spaces
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
# --- One-Time Setup Function ---
def setup_data():
    """
    Checks for dataset metadata and prepares it if missing.
    This involves cloning a repo, running a script, and cleaning up.
    """
    data_dir = 'shakespeare_char'
    meta_path = os.path.join(data_dir, 'meta.pkl')
    if os.path.exists(meta_path):
        print("Dataset metadata found. Skipping setup.")
        return

    print("Dataset metadata not found. Starting one-time setup...")
    print("This may take a minute...")
    repo_url = "https://github.com/karpathy/nanoGPT"
    repo_dir = "nanoGPT"
    try:
        print(f"Cloning {repo_url}...")
        subprocess.run(["git", "clone", repo_url], check=True, capture_output=True)

        source_data_dir = os.path.join(repo_dir, 'data', 'shakespeare_char')
        print(f"Copying data from {source_data_dir} to {data_dir}...")
        shutil.copytree(source_data_dir, data_dir)

        prepare_script_path = os.path.join(data_dir, 'prepare.py')
        print(f"Running {prepare_script_path} to generate metadata...")
        subprocess.run([sys.executable, prepare_script_path], check=True, capture_output=True)
        print("Setup successful. 'meta.pkl' has been created.")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred during setup: {e}", file=sys.stderr)
        print(f"Stdout: {e.stdout.decode()}", file=sys.stderr)
        print(f"Stderr: {e.stderr.decode()}", file=sys.stderr)
        sys.exit("Setup failed. Please check your git installation and internet connection.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}", file=sys.stderr)
        sys.exit("Setup failed.")
    finally:
        if os.path.exists(repo_dir):
            print(f"Cleaning up by removing '{repo_dir}' directory...")
            shutil.rmtree(repo_dir)
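# Note: nanoGPT's shakespeare_char/prepare.py downloads input.txt and should leave
# train.bin, val.bin and meta.pkl in the data directory; only meta.pkl (the
# character <-> integer mappings) is needed by this demo.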
# --- Run Setup and Load Data ---
setup_data()

# Load metadata for character mappings
data_dir = './shakespeare_char/'
meta_path = os.path.join(data_dir, 'meta.pkl')
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
itos = meta['itos']
stoi = meta['stoi']
vocab_size = meta['vocab_size']

CONTEXT_LENGTH = 256


def decode(indices_tensor: torch.Tensor):
    if indices_tensor.dim() == 2:
        indices_tensor = indices_tensor[0]
    indices = indices_tensor.cpu().numpy()
    return ''.join([itos[i] for i in indices])


def wrap_text(long_text, width=80):
    paragraphs = long_text.splitlines()
    wrapped = [textwrap.fill(p, width=width) if p else '' for p in paragraphs]
    return "\n".join(wrapped)
# --- Model Architecture ---
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))


class SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        if self.flash:
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False)
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
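# Note that, unlike an autoregressive GPT, the attention above is intentionally
# bidirectional (is_causal=False, and no mask in the fallback path): the denoiser
# conditions on the entire noisy sequence at once, not just the left context.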
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return x * (1 + scale) + shift


def bias_add_scale(x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor:
    out = scale * (x + bias) if bias is not None else scale * x
    return residual + out if residual is not None else out
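# These two helpers implement DiT-style adaLN conditioning: modulate() applies a
# conditioning-dependent affine transform after LayerNorm, and bias_add_scale()
# applies a conditioning-dependent gate before adding the residual branch back in.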
class DDiTBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.attn = SelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)
        self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
        x_skip = x
        x = modulate(self.ln_1(x), shift_msa, scale_msa)
        x = self.attn(x)
        x = bias_add_scale(x, None, gate_msa, x_skip)
        x = bias_add_scale(self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
        return x


class DDitFinalLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.norm_final = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.linear = nn.Linear(config.n_embd, config.vocab_size)
        self.linear.weight.data.zero_()
        self.linear.bias.data.zero_()
        self.adaLN_modulation = nn.Linear(config.cond_dim, 2 * config.n_embd)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
        x = modulate(self.norm_final(x), shift, scale)
        return self.linear(x)
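# Zero-initialising the adaLN projections and the final linear layer follows DiT's
# "adaLN-Zero" recipe, which is intended to let conditioning start from a neutral,
# identity-like contribution; for this demo it is moot anyway, since pretrained
# weights are loaded further below.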
class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)
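# The scalar noise level is embedded much like a Transformer position: a bank of
# sinusoids at geometrically spaced frequencies turns it into a
# frequency_embedding_size-dim vector, which a small MLP then maps to cond_dim.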
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.sigma_map = TimestepEmbedder(config.cond_dim)
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([DDiTBlock(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd, bias=config.bias),  # <<< FIX 1: ADDED THIS LAYER
        ))
        self.lm_head = DDitFinalLayer(config)
        self.apply(self._init_weights)
        # Apply special scaled init to the residual projections
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, sigma):
        sigma = sigma.reshape(-1)
        b, t = idx.size()
        c = F.silu(self.sigma_map(sigma))
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x, c)
        x = self.transformer.ln_f(x)  # <<< FIX 2: CALLED THE LAYER HERE
        x = self.lm_head(x, c)
        # Force the output at each current token's index to 0; the sampler
        # exponentiates this log-score, so the score there becomes exp(0) = 1.
        return torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))


@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    cond_dim: int = 64
    dropout: float = 0.0
    bias: bool = False
# --- Noise Schedule & Sampling Logic ---
class GeometricNoise:
    def __init__(self, sigma_min=1e-4, sigma_max=20):
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def total_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t

    def __call__(self, t):
        return self.total_noise(t), None  # Rate not needed for sampling
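# Geometric schedule: the accumulated noise interpolates log-linearly between the
# extremes, sigma_bar(t) = sigma_min**(1 - t) * sigma_max**t, so sigma_bar(0) = sigma_min
# (nearly clean) and sigma_bar(1) = sigma_max (nearly uniform noise).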
def transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor:
    base_prob = (1 - torch.exp(-delta_sigma[..., None])) / vocab_size
    trans = torch.ones(*x_t.shape, vocab_size, device=x_t.device) * base_prob
    trans = trans.scatter(-1, x_t[..., None], torch.zeros_like(trans))
    diag_fill = 1 - trans.sum(dim=-1, keepdim=True)
    return trans.scatter(-1, x_t[..., None], diag_fill)
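# Uniform transition kernel for one step of size delta_sigma: every other token gets
# probability (1 - exp(-delta_sigma)) / vocab_size, and the current token keeps the
# remaining mass, so each row sums to 1.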
def staggered_score(score, delta_sigma):
    exp_factor = torch.exp(-delta_sigma)[..., None]
    correction = ((exp_factor - 1) / (vocab_size * exp_factor)) * score.sum(dim=-1, keepdim=True)
    return correction + score / exp_factor
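# Analytic rescaling of the model's score from the current noise level to the next,
# less-noisy one under the uniform kernel (the "staggered score" used in SEDD-style
# samplers): scores are divided by exp(-delta_sigma) and shifted by a correction that
# depends only on their sum over the vocabulary.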
def sample_categorical(probs: torch.Tensor) -> torch.Tensor:
    eps = 1e-10
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + eps) + eps)
    return torch.argmax(torch.log(probs + eps) + gumbel_noise, dim=-1)
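# Gumbel-max trick: adding independent Gumbel(0, 1) noise to the log-probabilities
# and taking the argmax draws an exact sample from the categorical distribution, and
# it works even when `probs` is unnormalised (as it is in the sampling loop below).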
# --- Global Model Loading ---
print("Setting up model and device...")
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("===================================")
print(f"Using device: {DEVICE}")
print("===================================")

model_args = dict(n_layer=6, n_head=6, n_embd=384, cond_dim=64,
                  bias=False, vocab_size=vocab_size, block_size=CONTEXT_LENGTH, dropout=0.2)
config = GPTConfig(**model_args)
model = GPT(config)

print("Loading pre-trained model weights...")
model.load_state_dict(
    torch.hub.load_state_dict_from_url(
        'https://raw.githubusercontent.com/ash80/diffusion-gpt/master/pretrained_model/model_epoch_25.pth',
        map_location=DEVICE
    )
)
model.to(DEVICE)
model.eval()

NOISE = GeometricNoise(sigma_min=1e-4, sigma_max=20)
print("Model setup complete. Launching Gradio demo...")
# --- Gradio Generation Function ---
@spaces.GPU  # request a GPU slot when running on ZeroGPU hardware
def generate_text(steps, progress=gr.Progress()):
    steps = int(steps)
    eps = 1e-5
    timesteps = torch.linspace(1, eps, steps + 1, device=DEVICE)
    step_size = (1 - eps) / steps

    # Start from pure noise: every position is an independent uniform random character.
    x = torch.randint(0, vocab_size, (1, CONTEXT_LENGTH), device=DEVICE)
    initial_text = decode(x)
    yield f"Step 0/{steps} (Initial Noise):\n\n{wrap_text(initial_text)}"
    time.sleep(0.5)

    with torch.no_grad():
        for i in range(steps):
            progress(i / steps, desc=f"Denoising Step {i+1}/{steps}")
            t = timesteps[i] * torch.ones(x.shape[0], 1, device=DEVICE)
            curr_sigma_bar, _ = NOISE(t)
            next_t = t - step_size
            next_sigma_bar, _ = NOISE(next_t)
            delta_sigma = curr_sigma_bar - next_sigma_bar

            log_score = model(x, curr_sigma_bar)
            score = torch.exp(log_score)
            stag_score = staggered_score(score, delta_sigma)
            probs = stag_score * transition(x, delta_sigma)
            x = sample_categorical(probs)

            decoded_text = decode(x)
            yield f"Step {i+1}/{steps}:\n\n{wrap_text(decoded_text)}"

    final_text = decode(x)
    yield f"Final Result (Step {steps}/{steps}):\n\n{wrap_text(final_text)}"
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # The Annotated Discrete Diffusion Model: Live Demo
        This demo visualizes the denoising process of a character-level discrete diffusion model.
        Start with pure random noise and watch as coherent text, in the style of Shakespeare, emerges over several steps.
        """
    )
    with gr.Row():
        steps_slider = gr.Slider(
            minimum=10,
            maximum=200,
            value=128,
            step=1,
            label="Number of Denoising Steps",
            info="More steps can lead to better quality but take longer."
        )
        generate_button = gr.Button("Generate", variant="primary")
    output_textbox = gr.Textbox(
        label="Denoising Process",
        lines=15,
        interactive=False,
        show_copy_button=True,
        placeholder="The denoising process will appear here..."
    )
    generate_button.click(
        fn=generate_text,
        inputs=[steps_slider],
        outputs=[output_textbox]
    )

if __name__ == "__main__":
    demo.launch()