Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -302,41 +302,61 @@ print("Model setup complete. Launching Gradio demo...")
|
|
| 302 |
# --- Gradio Generation Function ---
|
| 303 |
@spaces.GPU
|
| 304 |
def generate_text(steps):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
steps = int(steps)
|
| 306 |
eps = 1e-5
|
| 307 |
timesteps = torch.linspace(1, eps, steps + 1, device=DEVICE)
|
| 308 |
step_size = (1 - eps) / steps
|
| 309 |
|
|
|
|
| 310 |
x = torch.randint(0, vocab_size, (1, CONTEXT_LENGTH), device=DEVICE)
|
| 311 |
|
|
|
|
| 312 |
initial_text = decode(x)
|
| 313 |
yield f"Step 0/{steps} (Initial Noise):\n\n{wrap_text(initial_text)}"
|
| 314 |
time.sleep(0.5)
|
| 315 |
|
| 316 |
with torch.no_grad():
|
| 317 |
-
for i in range(steps):
|
|
|
|
| 318 |
t = timesteps[i] * torch.ones(x.shape[0], 1, device=DEVICE)
|
| 319 |
curr_sigma_bar, _ = NOISE(t)
|
| 320 |
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
decoded_text = decode(x)
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
|
| 338 |
# --- Gradio Interface ---
|
| 339 |
-
with gr.Blocks(theme=gr.themes.
|
| 340 |
gr.Markdown(
|
| 341 |
"""
|
| 342 |
# The Annotated Discrete Diffusion Model: Live Demo
|
|
|
|
| 302 |
# --- Gradio Generation Function ---
|
| 303 |
@spaces.GPU
def generate_text(steps):
    """Yield progressively denoised text, one string per sampling update.

    Starts from uniformly random token ids and repeatedly resamples the
    sequence along a time grid running from 1 down to ``eps``. The loop
    performs ``steps`` intermediate updates followed by one final update
    that removes all of the remaining noise, so ``steps + 2`` strings are
    yielded in total (including the initial noise).

    Per the original author's note, the sampling logic mirrors the loop of
    the original Colab notebook.

    Args:
        steps: Number of intermediate denoising steps. Coerced with
            ``int()``; must be at least 1 after coercion.

    Yields:
        str: A step label followed by the decoded, wrapped text: first for
        the initial noise, then after every intermediate update, and last
        for the final result.

    Raises:
        ValueError: If ``steps`` is less than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    steps = int(steps)
    if steps < 1:
        # Without this guard steps == 0 fails with a bare ZeroDivisionError
        # on step_size below, and negative values give a nonsensical grid.
        raise ValueError(f"steps must be at least 1, got {steps}")

    eps = 1e-5
    timesteps = torch.linspace(1, eps, steps + 1, device=DEVICE)
    step_size = (1 - eps) / steps

    # Start with a fresh random sample.
    x = torch.randint(0, vocab_size, (1, CONTEXT_LENGTH), device=DEVICE)

    # Show the initial random text before any denoising happens.
    initial_text = decode(x)
    yield f"Step 0/{steps} (Initial Noise):\n\n{wrap_text(initial_text)}"
    time.sleep(0.5)

    for i in range(steps + 1):
        is_final = i == steps

        # no_grad is scoped to the computation only, so grad mode is not
        # left disabled in the consumer while this generator is suspended
        # at a yield.
        with torch.no_grad():
            t = timesteps[i] * torch.ones(x.shape[0], 1, device=DEVICE)
            curr_sigma_bar, _ = NOISE(t)

            if is_final:
                # Final, full denoising step: the "next sigma" is 0, so
                # delta_sigma is the entire current noise.
                delta_sigma = curr_sigma_bar
            else:
                # Intermediate step: only remove the noise between this
                # time and the next one.
                next_sigma_bar, _ = NOISE(t - step_size)
                delta_sigma = curr_sigma_bar - next_sigma_bar

            # The update itself is identical for both kinds of step; only
            # delta_sigma differs.
            log_score = model(x, curr_sigma_bar)
            score = torch.exp(log_score)
            stag_score = staggered_score(score, delta_sigma)
            probs = stag_score * transition(x, delta_sigma)
            x = sample_categorical(probs)

        # Yield the decoded text after each update; the last yield is the
        # final result.
        decoded_text = decode(x)
        if is_final:
            yield f"Final Result (Step {steps}/{steps}):\n\n{wrap_text(decoded_text)}"
        else:
            yield f"Step {i+1}/{steps}:\n\n{wrap_text(decoded_text)}"
|
| 357 |
|
| 358 |
# --- Gradio Interface ---
|
| 359 |
+
with gr.Blocks(theme=gr.themes.Citrus()) as demo:
|
| 360 |
gr.Markdown(
|
| 361 |
"""
|
| 362 |
# The Annotated Discrete Diffusion Model: Live Demo
|