multimodalart HF Staff committed on
Commit
a835137
·
verified ·
1 Parent(s): 5d89594

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -182
app.py CHANGED
@@ -1,96 +1,123 @@
1
- # app.py (Corrected Version)
2
-
3
- import os
 
 
 
4
  import math
 
5
  import pickle
6
- import shutil
7
- import subprocess
8
- import sys
9
  import textwrap
10
- import time
 
11
  from dataclasses import dataclass
12
  from typing import Optional
13
 
14
- import spaces
15
- import gradio as gr
16
- import numpy as np
17
- import torch
18
- import torch.nn as nn
19
- from torch.nn import functional as F
20
-
21
- # --- One-Time Setup Function ---
22
 
23
- def setup_data():
24
  """
25
- Checks for dataset metadata and prepares it if missing.
26
- This involves cloning a repo, running a script, and cleaning up.
 
 
 
27
  """
28
- data_dir = 'shakespeare_char'
29
- meta_path = os.path.join(data_dir, 'meta.pkl')
 
30
 
 
31
  if os.path.exists(meta_path):
32
- print("Dataset metadata found. Skipping setup.")
33
  return
34
 
35
- print("Dataset metadata not found. Starting one-time setup...")
36
- print("This may take a minute...")
37
-
38
- repo_url = "https://github.com/karpathy/nanoGPT"
39
- repo_dir = "nanoGPT"
40
-
41
- try:
42
- print(f"Cloning {repo_url}...")
43
- subprocess.run(["git", "clone", repo_url], check=True, capture_output=True)
44
-
45
- source_data_dir = os.path.join(repo_dir, 'data', 'shakespeare_char')
46
- print(f"Copying data from {source_data_dir} to {data_dir}...")
47
- shutil.copytree(source_data_dir, data_dir)
48
-
49
- prepare_script_path = os.path.join(data_dir, 'prepare.py')
50
- print(f"Running {prepare_script_path} to generate metadata...")
51
- subprocess.run([sys.executable, prepare_script_path], check=True, capture_output=True)
52
-
53
- print("Setup successful. 'meta.pkl' has been created.")
54
-
55
- except subprocess.CalledProcessError as e:
56
- print(f"An error occurred during setup: {e}", file=sys.stderr)
57
- print(f"Stdout: {e.stdout.decode()}", file=sys.stderr)
58
- print(f"Stderr: {e.stderr.decode()}", file=sys.stderr)
59
- sys.exit("Setup failed. Please check your git installation and internet connection.")
60
- except Exception as e:
61
- print(f"An unexpected error occurred: {e}", file=sys.stderr)
62
- sys.exit("Setup failed.")
63
- finally:
64
- if os.path.exists(repo_dir):
65
- print(f"Cleaning up by removing '{repo_dir}' directory...")
66
- shutil.rmtree(repo_dir)
67
-
68
- # --- Run Setup and Load Data ---
69
- setup_data()
70
-
71
- # Load metadata for character mappings
 
 
 
 
 
 
 
 
 
 
 
 
72
  data_dir = './shakespeare_char/'
73
  meta_path = os.path.join(data_dir, 'meta.pkl')
74
  with open(meta_path, 'rb') as f:
75
  meta = pickle.load(f)
76
 
 
77
  itos = meta['itos']
78
  stoi = meta['stoi']
79
- vocab_size = meta['vocab_size']
80
- CONTEXT_LENGTH = 256
81
 
82
  def decode(indices_tensor: torch.Tensor):
83
- if indices_tensor.dim() == 2:
84
- indices_tensor = indices_tensor[0]
 
85
  indices = indices_tensor.cpu().numpy()
86
- return ''.join([itos[i] for i in indices])
87
 
88
  def wrap_text(long_text, width=80):
 
89
  paragraphs = long_text.splitlines()
90
  wrapped = [textwrap.fill(p, width=width) if p else '' for p in paragraphs]
91
  return "\n".join(wrapped)
92
 
93
- # --- Model Architecture ---
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  class MLP(nn.Module):
96
  def __init__(self, config):
@@ -100,7 +127,11 @@ class MLP(nn.Module):
100
  self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
101
  self.dropout = nn.Dropout(config.dropout)
102
  def forward(self, x):
103
- return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
 
 
 
 
104
 
105
  class SelfAttention(nn.Module):
106
  def __init__(self, config):
@@ -121,21 +152,27 @@ class SelfAttention(nn.Module):
121
  q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
122
  v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
123
  if self.flash:
124
- y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False)
125
  else:
126
  att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
127
  att = F.softmax(att, dim=-1)
128
  att = self.attn_dropout(att)
129
  y = att @ v
130
  y = y.transpose(1, 2).contiguous().view(B, T, C)
131
- return self.resid_dropout(self.c_proj(y))
 
132
 
133
  def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
134
  return x * (1 + scale) + shift
135
 
136
  def bias_add_scale(x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor:
137
- out = scale * (x + bias) if bias is not None else scale * x
138
- return residual + out if residual is not None else out
 
 
 
 
 
139
 
140
  class DDiTBlock(nn.Module):
141
  def __init__(self, config):
@@ -144,15 +181,14 @@ class DDiTBlock(nn.Module):
144
  self.attn = SelfAttention(config)
145
  self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
146
  self.mlp = MLP(config)
147
-
148
  self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd)
149
  self.adaLN_modulation.weight.data.zero_()
150
  self.adaLN_modulation.bias.data.zero_()
151
-
152
  def forward(self, x, c):
153
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
154
- x_skip = x
155
- modulated_x = modulate(self.ln_1(x), shift_msa, scale_msa)
 
156
  x = bias_add_scale(self.attn(self.ln_1(x)), None, gate_msa, x_skip)
157
  x = bias_add_scale(self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
158
  return x
@@ -170,7 +206,8 @@ class DDitFinalLayer(nn.Module):
170
  def forward(self, x, c):
171
  shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
172
  x = modulate(self.norm_final(x), shift, scale)
173
- return self.linear(x)
 
174
 
175
  class TimestepEmbedder(nn.Module):
176
  def __init__(self, hidden_size, frequency_embedding_size=256):
@@ -184,7 +221,9 @@ class TimestepEmbedder(nn.Module):
184
  @staticmethod
185
  def timestep_embedding(t, dim, max_period=10000):
186
  half = dim // 2
187
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
 
 
188
  args = t[:, None].float() * freqs[None]
189
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
190
  if dim % 2:
@@ -192,11 +231,14 @@ class TimestepEmbedder(nn.Module):
192
  return embedding
193
  def forward(self, t):
194
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
195
- return self.mlp(t_freq)
 
196
 
197
  class GPT(nn.Module):
198
  def __init__(self, config):
199
  super().__init__()
 
 
200
  self.config = config
201
  self.sigma_map = TimestepEmbedder(config.cond_dim)
202
  self.transformer = nn.ModuleDict(dict(
@@ -204,16 +246,13 @@ class GPT(nn.Module):
204
  wpe = nn.Embedding(config.block_size, config.n_embd),
205
  drop = nn.Dropout(config.dropout),
206
  h = nn.ModuleList([DDiTBlock(config) for _ in range(config.n_layer)]),
207
- ln_f = nn.LayerNorm(config.n_embd, bias=config.bias), # <<< FIX 1: ADDED THIS LAYER
208
  ))
209
  self.lm_head = DDitFinalLayer(config)
210
  self.apply(self._init_weights)
211
-
212
- # Apply special scaled init to the residual projections
213
  for pn, p in self.named_parameters():
214
  if pn.endswith('c_proj.weight'):
215
  torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
216
-
217
  def _init_weights(self, module):
218
  if isinstance(module, nn.Linear):
219
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -221,48 +260,41 @@ class GPT(nn.Module):
221
  torch.nn.init.zeros_(module.bias)
222
  elif isinstance(module, nn.Embedding):
223
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
224
-
225
  def forward(self, idx, sigma):
226
  sigma = sigma.reshape(-1)
227
  b, t = idx.size()
228
  c = F.silu(self.sigma_map(sigma))
229
- pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
 
230
  tok_emb = self.transformer.wte(idx)
231
  pos_emb = self.transformer.wpe(pos)
232
  x = self.transformer.drop(tok_emb + pos_emb)
233
  for block in self.transformer.h:
234
  x = block(x, c)
235
- x = self.transformer.ln_f(x) # <<< FIX 2: CALLED THE LAYER HERE
236
  x = self.lm_head(x, c)
237
- return torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))
238
-
239
- @dataclass
240
- class GPTConfig:
241
- block_size: int = 1024
242
- vocab_size: int = 50304
243
- n_layer: int = 12
244
- n_head: int = 12
245
- n_embd: int = 768
246
- cond_dim: int = 64
247
- dropout: float = 0.0
248
- bias: bool = False
249
-
250
- # --- Noise Schedule & Sampling Logic ---
251
 
252
  class GeometricNoise:
253
  def __init__(self, sigma_min=1e-4, sigma_max=20):
254
- self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
 
 
255
  def total_noise(self, t):
256
  return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
257
  def __call__(self, t):
258
- return self.total_noise(t), None # Rate not needed for sampling
 
 
259
 
260
  def transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor:
261
  base_prob = (1 - torch.exp(-delta_sigma[..., None])) / vocab_size
262
  trans = torch.ones(*x_t.shape, vocab_size, device=x_t.device) * base_prob
263
  trans = trans.scatter(-1, x_t[..., None], torch.zeros_like(trans))
264
  diag_fill = 1 - trans.sum(dim=-1, keepdim=True)
265
- return trans.scatter(-1, x_t[..., None], diag_fill)
 
266
 
267
  def staggered_score(score, delta_sigma):
268
  exp_factor = torch.exp(-delta_sigma)[..., None]
@@ -274,119 +306,114 @@ def sample_categorical(probs: torch.Tensor) -> torch.Tensor:
274
  gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + eps) + eps)
275
  return torch.argmax(torch.log(probs + eps) + gumbel_noise, dim=-1)
276
 
277
- # --- Global Model Loading ---
278
 
279
- print("Setting up model and device...")
280
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
281
- print(f"===================================")
282
- print(f"Using device: {DEVICE}")
283
- print(f"===================================")
284
 
 
285
  model_args = dict(n_layer=6, n_head=6, n_embd=384, cond_dim=64,
286
- bias=False, vocab_size=vocab_size, block_size=CONTEXT_LENGTH, dropout=0.2)
287
  config = GPTConfig(**model_args)
288
  model = GPT(config)
289
 
290
- print("Loading pre-trained model weights...")
291
  model.load_state_dict(
292
  torch.hub.load_state_dict_from_url(
293
  'https://raw.githubusercontent.com/ash80/diffusion-gpt/master/pretrained_model/model_epoch_25.pth',
294
- map_location=DEVICE
295
  )
296
  )
297
- model.to(DEVICE)
298
  model.eval()
299
 
300
- NOISE = GeometricNoise(sigma_min=1e-4, sigma_max=20)
301
- print("Model setup complete. Launching Gradio demo...")
302
 
303
- # --- Gradio Generation Function ---
 
304
  @spaces.GPU
305
  def generate_text(steps):
306
  """
307
- Generator function that yields denoised text at each step.
308
- This logic is a 1:1 copy of the original Colab notebook's sampling loop.
309
  """
310
  steps = int(steps)
311
  eps = 1e-5
312
- timesteps = torch.linspace(1, eps, steps + 1, device=DEVICE)
313
- step_size = (1 - eps) / steps
314
-
315
- # Start with a fresh random sample
316
- x = torch.randint(0, vocab_size, (1, CONTEXT_LENGTH), device=DEVICE)
317
-
318
- # Initial random text
319
- initial_text = decode(x)
320
- yield f"Step 0/{steps} (Initial Noise):\n\n{wrap_text(initial_text)}"
321
- time.sleep(0.5)
322
 
 
 
 
323
  with torch.no_grad():
324
- for i in range(steps + 1):
325
-
326
- t = timesteps[i] * torch.ones(x.shape[0], 1, device=DEVICE)
327
- curr_sigma_bar, _ = NOISE(t)
328
-
329
- if i < steps:
330
- # This is an intermediate denoising step
331
- next_sigma_bar, _ = NOISE(t - step_size)
332
- delta_sigma = curr_sigma_bar - next_sigma_bar
333
-
334
- log_score = model(x, curr_sigma_bar)
335
- score = torch.exp(log_score)
336
- stag_score = staggered_score(score, delta_sigma)
337
- probs = stag_score * transition(x, delta_sigma)
338
- x = sample_categorical(probs)
339
-
340
- else:
341
- # This is the final, full denoising step
342
- # The "next sigma" is 0, so delta_sigma is the entire current noise.
343
- delta_sigma = curr_sigma_bar
344
-
345
- log_score = model(x, curr_sigma_bar)
346
- score = torch.exp(log_score)
347
- stag_score = staggered_score(score, delta_sigma)
348
- probs = stag_score * transition(x, delta_sigma)
349
- x = sample_categorical(probs)
350
-
351
- # Yield the decoded text after each step
352
- # The last yield will be the final result
353
- decoded_text = decode(x)
354
- if i < steps:
355
- yield f"Step {i+1}/{steps}:\n\n{wrap_text(decoded_text)}"
356
- else:
357
- yield f"Final Result (Step {steps}/{steps}):\n\n{wrap_text(decoded_text)}"
358
-
359
- # --- Gradio Interface ---
360
  with gr.Blocks(theme=gr.themes.Citrus()) as demo:
361
  gr.Markdown(
362
  """
363
- # The Annotated Discrete Diffusion Model: Live Demo
364
- This demo visualizes the denoising process of a character-level discrete diffusion model.
365
- Start with pure random noise and watch as coherent text, in the style of Shakespeare, emerges over several steps.
366
  """
367
  )
368
- with gr.Row():
369
- steps_slider = gr.Slider(
370
- minimum=10,
371
- maximum=200,
372
- value=128,
373
- step=1,
374
- label="Number of Denoising Steps",
375
- info="More steps can lead to better quality but take longer."
376
- )
377
- generate_button = gr.Button("Generate", variant="primary")
378
-
 
379
  output_textbox = gr.Textbox(
380
- label="Denoising Process",
381
- lines=15,
382
- interactive=False,
383
  show_copy_button=True,
384
- placeholder="The denoising process will appear here..."
385
  )
386
-
387
  generate_button.click(
388
- fn=generate_text,
389
- inputs=[steps_slider],
390
  outputs=[output_textbox]
391
  )
392
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
  import math
8
+ import os
9
  import pickle
10
+ import requests
 
 
11
  import textwrap
12
+ import subprocess
13
+ import shutil
14
  from dataclasses import dataclass
15
  from typing import Optional
16
 
 
 
 
 
 
 
 
 
17
 
18
def setup_environment():
    """
    Prepare the shakespeare_char dataset directory if it is not already set up.

    Each step is skipped when its output already exists:
    - clone the nanoGPT repository into ./nanoGPT,
    - copy nanoGPT/data/shakespeare_char to ./shakespeare_char,
    - run shakespeare_char/prepare.py from inside that directory.

    Returns immediately when shakespeare_char/meta.pkl is already present.

    Raises:
        subprocess.CalledProcessError: if `git clone` or `prepare.py` exits
            with a non-zero status (the captured stderr is printed first).
    """
    import sys  # local import: only this function needs the interpreter path

    nano_gpt_repo_path = 'nanoGPT'
    data_dir_path = 'shakespeare_char'
    meta_path = os.path.join(data_dir_path, 'meta.pkl')

    # If the final metadata file already exists, we assume setup is complete.
    if os.path.exists(meta_path):
        print("Dataset and metadata found. Skipping setup.")
        return

    print("Required data not found. Starting one-time setup...")

    # 1. Clone nanoGPT repository if it doesn't exist
    if not os.path.exists(nano_gpt_repo_path):
        print("Cloning nanoGPT repository...")
        try:
            subprocess.run(
                ['git', 'clone', 'https://github.com/karpathy/nanoGPT.git'],
                check=True, capture_output=True, text=True
            )
        except subprocess.CalledProcessError as e:
            print(f"Error cloning repository: {e.stderr}", file=sys.stderr)
            raise
        print("Cloned successfully.")
    else:
        print("nanoGPT repository already exists.")

    # 2. Copy the dataset directory if it doesn't exist
    source_data_dir = os.path.join(nano_gpt_repo_path, 'data', 'shakespeare_char')
    if not os.path.exists(data_dir_path):
        print(f"Copying '{source_data_dir}' to '{data_dir_path}'...")
        shutil.copytree(source_data_dir, data_dir_path)
        print("Copied successfully.")
    else:
        print(f"'{data_dir_path}' directory already exists.")

    # 3. Run the data preparation script
    prepare_script_path = os.path.join(data_dir_path, 'prepare.py')
    if not os.path.exists(meta_path):
        print(f"Running data preparation script: '{prepare_script_path}'...")
        try:
            # Run from within the data directory so the script finds its input
            # file, and use the interpreter running this app rather than
            # whatever `python` happens to be first on PATH.
            subprocess.run(
                [sys.executable, 'prepare.py'],
                check=True, cwd=data_dir_path, capture_output=True, text=True
            )
        except subprocess.CalledProcessError as e:
            print(f"Error running prepare.py: {e.stderr}", file=sys.stderr)
            raise
        print("Data preparation script finished successfully.")

    print("Setup complete.")
77
+
78
+ # Run the setup process before anything else
79
setup_environment()

# --- 2. Global Setup & Helper Functions ---

# Read the pickled metadata dictionary from the dataset directory.
data_dir = './shakespeare_char/'
meta_path = os.path.join(data_dir, 'meta.pkl')
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)

# Vocabulary size and the two lookup tables stored in the metadata.
vocab_size, itos, stoi = (meta[key] for key in ('vocab_size', 'itos', 'stoi'))

# Sequence length used for sampling and as the model's block size.
context_length = 256
device = 'cpu' if not torch.cuda.is_available() else 'cuda'
94
 
95
def decode(indices_tensor: torch.Tensor):
    """Decode a tensor of token indices into text via the module-level `itos` table.

    A tensor with more than one dimension is reduced to its first entry along
    each leading dimension, so a (batch, length) input decodes its first row.
    Indices that are missing from `itos` are rendered as '?'.
    """
    # Indexing (rather than squeeze(0)) also handles a leading dimension > 1,
    # where squeeze would be a no-op and leave a 2-D tensor behind.
    while indices_tensor.dim() > 1:
        indices_tensor = indices_tensor[0]
    # tolist() yields plain Python ints, which are valid dictionary keys.
    indices = indices_tensor.reshape(-1).cpu().tolist()
    return ''.join(itos.get(i, '?') for i in indices)
101
 
102
def wrap_text(long_text, width=80):
    """Re-flow each line of `long_text` to at most `width` columns.

    The text is split with `splitlines()`; every non-empty line is wrapped
    independently with `textwrap.fill`, and empty lines are kept as empty
    strings so blank-line paragraph breaks survive.
    """
    wrapped_lines = []
    for paragraph in long_text.splitlines():
        if paragraph:
            wrapped_lines.append(textwrap.fill(paragraph, width=width))
        else:
            wrapped_lines.append('')
    return "\n".join(wrapped_lines)
107
 
108
+
109
+ # --- 3. Model Architecture (Identical to Notebook) ---
110
+
111
@dataclass
class GPTConfig:
    """Hyperparameters read by the model classes in this file."""

    # Size of the position-embedding table; forward() rejects longer inputs.
    block_size: int = 1024
    vocab_size: int = 50304
    # Number of DDiTBlock layers stacked in the transformer.
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    # Width of the noise-level conditioning vector fed to adaLN modulation.
    cond_dim: int = 64
    dropout: float = 0.0
    # Passed as `bias=` to the Linear and LayerNorm layers.
    bias: bool = False
121
 
122
  class MLP(nn.Module):
123
  def __init__(self, config):
 
127
  self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
128
  self.dropout = nn.Dropout(config.dropout)
129
  def forward(self, x):
130
+ x = self.c_fc(x)
131
+ x = self.gelu(x)
132
+ x = self.c_proj(x)
133
+ x = self.dropout(x)
134
+ return x
135
 
136
  class SelfAttention(nn.Module):
137
  def __init__(self, config):
 
152
  q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
153
  v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
154
  if self.flash:
155
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False)
156
  else:
157
  att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
158
  att = F.softmax(att, dim=-1)
159
  att = self.attn_dropout(att)
160
  y = att @ v
161
  y = y.transpose(1, 2).contiguous().view(B, T, C)
162
+ y = self.resid_dropout(self.c_proj(y))
163
+ return y
164
 
165
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply an affine modulation: multiply `x` by `1 + scale`, then add `shift`."""
    scaled = x * (1 + scale)
    return scaled + shift
167
 
168
def bias_add_scale(x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor:
    """Return `scale * (x + bias)`, optionally added onto `residual`.

    `bias` and `residual` are each skipped when passed as None.
    """
    pre_gate = x if bias is None else x + bias
    gated = scale * pre_gate
    return gated if residual is None else residual + gated
176
 
177
  class DDiTBlock(nn.Module):
178
  def __init__(self, config):
 
181
  self.attn = SelfAttention(config)
182
  self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
183
  self.mlp = MLP(config)
 
184
  self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd)
185
  self.adaLN_modulation.weight.data.zero_()
186
  self.adaLN_modulation.bias.data.zero_()
 
187
  def forward(self, x, c):
188
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
189
+ x_skip = x
190
+ x = modulate(self.ln_1(x), shift_msa, scale_msa)
191
+ x = self.attn(x)
192
  x = bias_add_scale(self.attn(self.ln_1(x)), None, gate_msa, x_skip)
193
  x = bias_add_scale(self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
194
  return x
 
206
  def forward(self, x, c):
207
  shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
208
  x = modulate(self.norm_final(x), shift, scale)
209
+ x = self.linear(x)
210
+ return x
211
 
212
  class TimestepEmbedder(nn.Module):
213
  def __init__(self, hidden_size, frequency_embedding_size=256):
 
221
  @staticmethod
222
  def timestep_embedding(t, dim, max_period=10000):
223
  half = dim // 2
224
+ freqs = torch.exp(
225
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
226
+ ).to(device=t.device)
227
  args = t[:, None].float() * freqs[None]
228
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
229
  if dim % 2:
 
231
  return embedding
232
  def forward(self, t):
233
  t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
234
+ t_emb = self.mlp(t_freq)
235
+ return t_emb
236
 
237
  class GPT(nn.Module):
238
  def __init__(self, config):
239
  super().__init__()
240
+ assert config.vocab_size is not None
241
+ assert config.block_size is not None
242
  self.config = config
243
  self.sigma_map = TimestepEmbedder(config.cond_dim)
244
  self.transformer = nn.ModuleDict(dict(
 
246
  wpe = nn.Embedding(config.block_size, config.n_embd),
247
  drop = nn.Dropout(config.dropout),
248
  h = nn.ModuleList([DDiTBlock(config) for _ in range(config.n_layer)]),
249
+ ln_f = nn.LayerNorm(config.n_embd, bias=config.bias),
250
  ))
251
  self.lm_head = DDitFinalLayer(config)
252
  self.apply(self._init_weights)
 
 
253
  for pn, p in self.named_parameters():
254
  if pn.endswith('c_proj.weight'):
255
  torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
 
256
  def _init_weights(self, module):
257
  if isinstance(module, nn.Linear):
258
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
260
  torch.nn.init.zeros_(module.bias)
261
  elif isinstance(module, nn.Embedding):
262
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
263
  def forward(self, idx, sigma):
264
  sigma = sigma.reshape(-1)
265
  b, t = idx.size()
266
  c = F.silu(self.sigma_map(sigma))
267
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
268
+ pos = torch.arange(0, t, dtype=torch.long, device=device)
269
  tok_emb = self.transformer.wte(idx)
270
  pos_emb = self.transformer.wpe(pos)
271
  x = self.transformer.drop(tok_emb + pos_emb)
272
  for block in self.transformer.h:
273
  x = block(x, c)
274
+ x = self.transformer.ln_f(x)
275
  x = self.lm_head(x, c)
276
+ x = torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))
277
+ return x
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
class GeometricNoise:
    """Noise schedule that interpolates geometrically between two sigma bounds.

    For a time `t`, the total noise is `sigma_min ** (1 - t) * sigma_max ** t`
    and the rate is that value multiplied by `log(sigma_max) - log(sigma_min)`.
    """

    def __init__(self, sigma_min=1e-4, sigma_max=20):
        bounds = torch.tensor([sigma_min, sigma_max])
        # Stored as a two-element float tensor on the module-level `device`.
        self.sigmas = (1.0 * bounds).to(device)

    def rate_noise(self, t):
        """Return the rate of the total noise at time `t`."""
        lo, hi = self.sigmas[0], self.sigmas[1]
        return lo ** (1 - t) * hi ** t * (hi.log() - lo.log())

    def total_noise(self, t):
        """Return the total noise at time `t`."""
        lo, hi = self.sigmas[0], self.sigmas[1]
        return lo ** (1 - t) * hi ** t

    def __call__(self, t):
        """Return the pair `(total_noise(t), rate_noise(t))`."""
        total = self.total_noise(t)
        rate = self.rate_noise(t)
        return total, rate
288
+
289
+ # --- 4. Inference & Sampling Logic (Identical to Notebook) ---
290
 
291
def transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor:
    """Build per-position transition probabilities over the vocabulary.

    Every token other than the current one gets `(1 - exp(-delta_sigma)) / vocab_size`;
    the current token gets the remaining mass so each row sums to one.
    The result has one extra trailing dimension of size `vocab_size`.
    """
    current_token = x_t[..., None]
    off_diagonal = (1 - torch.exp(-delta_sigma[..., None])) / vocab_size
    probs = torch.ones(*x_t.shape, vocab_size, device=x_t.device) * off_diagonal
    # Clear the current-token slot, then refill it with the leftover mass.
    probs = probs.scatter(-1, current_token, torch.zeros_like(probs))
    stay_prob = 1 - probs.sum(dim=-1, keepdim=True)
    return probs.scatter(-1, current_token, stay_prob)
298
 
299
  def staggered_score(score, delta_sigma):
300
  exp_factor = torch.exp(-delta_sigma)[..., None]
 
306
  gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + eps) + eps)
307
  return torch.argmax(torch.log(probs + eps) + gumbel_noise, dim=-1)
308
 
 
309
 
310
+ # --- 5. Model Initialization and Loading ---
 
 
 
 
311
 
312
print("Initializing and loading the pretrained model...")

# Architecture hyperparameters for the checkpoint loaded below.
model_args = {
    'n_layer': 6,
    'n_head': 6,
    'n_embd': 384,
    'cond_dim': 64,
    'bias': False,
    'vocab_size': vocab_size,
    'block_size': context_length,
    'dropout': 0.2,
}
config = GPTConfig(**model_args)
model = GPT(config)

# Download the checkpoint and load it straight onto the target device.
_pretrained_state = torch.hub.load_state_dict_from_url(
    'https://raw.githubusercontent.com/ash80/diffusion-gpt/master/pretrained_model/model_epoch_25.pth',
    map_location=device,
)
model.load_state_dict(_pretrained_state)
model.to(device)
model.eval()

# Noise schedule used by the sampling loop.
noise = GeometricNoise(sigma_min=1e-4, sigma_max=20)
print("Model loaded successfully.")
329
 
330
+
331
+ # --- 6. Gradio Interface ---
332
def _denoise_step(x, curr_sigma_bar, delta_sigma):
    """Sample the next token sequence for one reverse step at noise level `curr_sigma_bar`."""
    log_score = model(x, curr_sigma_bar)
    score = torch.exp(log_score)
    stag_score = staggered_score(score, delta_sigma)
    probs = stag_score * transition(x, delta_sigma)
    return sample_categorical(probs)


@spaces.GPU
def generate_text(steps):
    """
    Generator driving the Gradio demo: yields the decoded text after each denoising step.

    Args:
        steps: number of intermediate denoising steps; converted with `int()`
            and clamped to at least 1 so the step size is always defined.

    Yields:
        A string for the initial random sample, one string per intermediate
        step, and a last string for the final denoising step.
    """
    # Guard against steps <= 0, which would divide by zero below.
    steps = max(1, int(steps))
    eps = 1e-5

    # Start with a random sample
    x = torch.randint(0, vocab_size, (1, context_length), device=device)
    yield f"--- Initial Random Noise ---\n\n{wrap_text(decode(x[0]))}"

    timesteps = torch.linspace(1, eps, steps + 1, device=device)
    step_size = (1 - eps) / steps

    with torch.no_grad():
        for i in range(steps):
            t = timesteps[i] * torch.ones(x.shape[0], 1, device=device)
            curr_sigma_bar = noise(t)[0]
            next_sigma_bar = noise(t - step_size)[0]
            # Noise removed between this timestep and the next one.
            x = _denoise_step(x, curr_sigma_bar, curr_sigma_bar - next_sigma_bar)
            yield f"--- Denoising Step {i + 1}/{steps} ---\n\n{wrap_text(decode(x[0]))}"

        # Final denoising step: the target noise is 0, so the whole of the
        # current noise is removed at once.
        t = timesteps[steps] * torch.ones(x.shape[0], 1, device=device)
        curr_sigma_bar = noise(t)[0]
        x = _denoise_step(x, curr_sigma_bar, curr_sigma_bar)

        yield f"--- Final Denoised Text (Step {steps}) ---\n\n{wrap_text(decode(x[0]))}"
383
+
384
+
385
+ # Define the Gradio UI
386
  with gr.Blocks(theme=gr.themes.Citrus()) as demo:
387
  gr.Markdown(
388
  """
389
+ # The Annotated Discrete Diffusion Models
390
+ This Gradio demo provides an interactive implementation of the character-level discrete diffusion model from the notebook.
391
+ The model starts with random characters (noise) and iteratively denoises them to generate coherent text in the style of Shakespeare.
392
  """
393
  )
394
+
395
+ steps_slider = gr.Slider(
396
+ minimum=10,
397
+ maximum=256,
398
+ value=128,
399
+ step=1,
400
+ label="Denoising Steps",
401
+ info="Number of steps in the reverse diffusion process."
402
+ )
403
+
404
+ generate_button = gr.Button("Generate", variant="primary")
405
+
406
  output_textbox = gr.Textbox(
407
+ label="Generated Text",
408
+ lines=15,
409
+ interactive=False,
410
  show_copy_button=True,
411
+ placeholder="Generation will appear here..."
412
  )
413
+
414
  generate_button.click(
415
+ fn=generate_text,
416
+ inputs=[steps_slider],
417
  outputs=[output_textbox]
418
  )
419