multimodalart HF Staff commited on
Commit
93eee8b
·
verified ·
1 Parent(s): 7da55d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -17
app.py CHANGED
# --- Gradio Generation Function ---
@spaces.GPU
def generate_text(steps):
    """Yield progressively denoised text, one message per reverse-diffusion step.

    Args:
        steps: number of denoising steps (coerced to ``int``; comes from the
            Gradio control, so it may arrive as a float).

    Yields:
        str: a progress banner plus the wrapped, decoded sample — first the
        initial noise, then one message per step, then a final-result message.
    """
    steps = int(steps)
    eps = 1e-5
    # Time grid from t=1 (pure noise) down to t=eps, steps+1 points.
    timesteps = torch.linspace(1, eps, steps + 1, device=DEVICE)
    step_size = (1 - eps) / steps

    # Fresh uniform-random token sequence as the starting point.
    x = torch.randint(0, vocab_size, (1, CONTEXT_LENGTH), device=DEVICE)

    yield f"Step 0/{steps} (Initial Noise):\n\n{wrap_text(decode(x))}"
    time.sleep(0.5)  # brief pause so the initial noise is visible in the UI

    with torch.no_grad():
        for i in range(steps):
            t = timesteps[i] * torch.ones(x.shape[0], 1, device=DEVICE)
            curr_sigma_bar, _ = NOISE(t)

            # Noise removed over this step is the drop in cumulative sigma.
            next_sigma_bar, _ = NOISE(t - step_size)
            delta_sigma = curr_sigma_bar - next_sigma_bar

            # Score-based categorical update: model log-scores -> staggered
            # score -> transition probabilities -> sample the next x.
            score = torch.exp(model(x, curr_sigma_bar))
            stag_score = staggered_score(score, delta_sigma)
            probs = stag_score * transition(x, delta_sigma)
            x = sample_categorical(probs)

            yield f"Step {i+1}/{steps}:\n\n{wrap_text(decode(x))}"

        final_text = decode(x)
        yield f"Final Result (Step {steps}/{steps}):\n\n{wrap_text(final_text)}"
337
 
338
  # --- Gradio Interface ---
339
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
340
  gr.Markdown(
341
  """
342
  # The Annotated Discrete Diffusion Model: Live Demo
 
# --- Gradio Generation Function ---
@spaces.GPU
def generate_text(steps):
    """
    Generator function that yields denoised text at each step.
    This logic is a 1:1 copy of the original Colab notebook's sampling loop.

    Args:
        steps: number of denoising steps (coerced to ``int``; arrives from
            the Gradio control, possibly as a float). Must be >= 1 —
            ``steps == 0`` would divide by zero below.

    Yields:
        str: a progress banner plus the wrapped, decoded sample — the
        initial noise first, then one message per step; the last yield is
        labelled "Final Result".
    """
    steps = int(steps)
    eps = 1e-5
    # Time grid from t=1 (pure noise) down to t=eps; steps+1 points, so
    # index `steps` is the final (smallest-time) step.
    timesteps = torch.linspace(1, eps, steps + 1, device=DEVICE)
    step_size = (1 - eps) / steps

    # Start with a fresh random sample
    x = torch.randint(0, vocab_size, (1, CONTEXT_LENGTH), device=DEVICE)

    # Initial random text
    initial_text = decode(x)
    yield f"Step 0/{steps} (Initial Noise):\n\n{wrap_text(initial_text)}"
    time.sleep(0.5)  # brief pause so the initial noise is visible in the UI

    with torch.no_grad():
        for i in range(steps + 1):
            t = timesteps[i] * torch.ones(x.shape[0], 1, device=DEVICE)
            curr_sigma_bar, _ = NOISE(t)

            if i < steps:
                # Intermediate denoising step: remove the noise accumulated
                # between t and t - step_size.
                next_sigma_bar, _ = NOISE(t - step_size)
                delta_sigma = curr_sigma_bar - next_sigma_bar
            else:
                # Final, full denoising step: the "next sigma" is 0, so
                # delta_sigma is the entire current noise.
                delta_sigma = curr_sigma_bar

            # Shared score-based update (previously duplicated verbatim in
            # both branches — only delta_sigma differed).
            log_score = model(x, curr_sigma_bar)
            score = torch.exp(log_score)
            stag_score = staggered_score(score, delta_sigma)
            probs = stag_score * transition(x, delta_sigma)
            x = sample_categorical(probs)

            # Yield the decoded text after each step; the last yield is the
            # final result.
            decoded_text = decode(x)
            if i < steps:
                yield f"Step {i+1}/{steps}:\n\n{wrap_text(decoded_text)}"
            else:
                yield f"Final Result (Step {steps}/{steps}):\n\n{wrap_text(decoded_text)}"
 
358
  # --- Gradio Interface ---
359
+ with gr.Blocks(theme=gr.themes.Citrus()) as demo:
360
  gr.Markdown(
361
  """
362
  # The Annotated Discrete Diffusion Model: Live Demo