Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -1,3 +1,5 @@
+# app.py (Corrected Version)
+
 import os
 import math
 import pickle

@@ -8,8 +10,8 @@ import textwrap
 import time
 from dataclasses import dataclass
 from typing import Optional
-import spaces
 
+import spaces
 import gradio as gr
 import numpy as np
 import torch

@@ -37,19 +39,15 @@ def setup_data():
     repo_dir = "nanoGPT"
 
     try:
-        # 1. Clone the repository
         print(f"Cloning {repo_url}...")
         subprocess.run(["git", "clone", repo_url], check=True, capture_output=True)
 
-        # 2. Copy the data directory
         source_data_dir = os.path.join(repo_dir, 'data', 'shakespeare_char')
         print(f"Copying data from {source_data_dir} to {data_dir}...")
         shutil.copytree(source_data_dir, data_dir)
 
-        # 3. Run the preparation script
         prepare_script_path = os.path.join(data_dir, 'prepare.py')
         print(f"Running {prepare_script_path} to generate metadata...")
-        # Use the same python executable that is running this script
         subprocess.run([sys.executable, prepare_script_path], check=True, capture_output=True)
 
         print("Setup successful. 'meta.pkl' has been created.")

@@ -63,7 +61,6 @@ def setup_data():
         print(f"An unexpected error occurred: {e}", file=sys.stderr)
         sys.exit("Setup failed.")
     finally:
-        # 4. Clean up the cloned repository
         if os.path.exists(repo_dir):
             print(f"Cleaning up by removing '{repo_dir}' directory...")
             shutil.rmtree(repo_dir)

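Note: the setup above reproduces nanoGPT's shakespeare_char pipeline, whose prepare.py writes a meta.pkl holding the character vocabulary that decode() below depends on. A minimal sketch of loading it, assuming nanoGPT's standard stoi/itos/vocab_size keys (illustrative, not a copy of this app's loading code):

import pickle

# meta.pkl is produced by data/shakespeare_char/prepare.py in the cloned repo
with open('data/shakespeare_char/meta.pkl', 'rb') as f:
    meta = pickle.load(f)

stoi = meta['stoi']              # char -> index
itos = meta['itos']              # index -> char
vocab_size = meta['vocab_size']  # 65 for the Shakespeare character set
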
@@ -83,20 +80,17 @@ vocab_size = meta['vocab_size']
 CONTEXT_LENGTH = 256
 
 def decode(indices_tensor: torch.Tensor):
-    '''Decodes a 1D tensor of indices to text'''
     if indices_tensor.dim() == 2:
         indices_tensor = indices_tensor[0]
     indices = indices_tensor.cpu().numpy()
     return ''.join([itos[i] for i in indices])
 
 def wrap_text(long_text, width=80):
-    """Wraps text to a maximum line width, preserving paragraph breaks."""
     paragraphs = long_text.splitlines()
     wrapped = [textwrap.fill(p, width=width) if p else '' for p in paragraphs]
     return "\n".join(wrapped)
 
-
-# --- Model Architecture (Copied from the notebook) ---
+# --- Model Architecture ---
 
 class MLP(nn.Module):
     def __init__(self, config):

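A quick check of wrap_text's behavior (self-contained copy of the helper above): empty lines from splitlines() pass through untouched, so paragraph breaks survive the wrapping.

import textwrap

def wrap_text(long_text, width=80):
    paragraphs = long_text.splitlines()
    wrapped = [textwrap.fill(p, width=width) if p else '' for p in paragraphs]
    return "\n".join(wrapped)

sample = "word " * 40 + "\n\nsecond paragraph"
print(wrap_text(sample))  # two wrapped paragraphs separated by a blank line
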
@@ -209,9 +203,16 @@ class GPT(nn.Module):
             wpe = nn.Embedding(config.block_size, config.n_embd),
             drop = nn.Dropout(config.dropout),
             h = nn.ModuleList([DDiTBlock(config) for _ in range(config.n_layer)]),
+            ln_f = nn.LayerNorm(config.n_embd, bias=config.bias), # <<< FIX 1: ADDED THIS LAYER
         ))
         self.lm_head = DDitFinalLayer(config)
         self.apply(self._init_weights)
+
+        # Apply special scaled init to the residual projections
+        for pn, p in self.named_parameters():
+            if pn.endswith('c_proj.weight'):
+                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

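Both additions mirror nanoGPT's GPT-2 conventions: a final LayerNorm before the output head (FIX 1), and residual-projection weights initialized with std scaled by 1/sqrt(2 * n_layer) so the residual stream's variance stays roughly constant with depth, assuming each block contributes two residual branches as in GPT-2. A standalone sketch of the scaled init, with illustrative module names:

import math
import torch
import torch.nn as nn

n_layer = 6

class Block(nn.Module):
    def __init__(self, n_embd=384):
        super().__init__()
        self.c_proj = nn.Linear(n_embd, n_embd)  # residual projection

blocks = nn.ModuleList([Block() for _ in range(n_layer)])

# Shrink only the residual projections; all other Linears keep std=0.02.
for pn, p in blocks.named_parameters():
    if pn.endswith('c_proj.weight'):
        torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layer))
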
@@ -219,6 +220,7 @@ class GPT(nn.Module):
             torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
     def forward(self, idx, sigma):
         sigma = sigma.reshape(-1)
         b, t = idx.size()

@@ -229,6 +231,7 @@ class GPT(nn.Module):
         x = self.transformer.drop(tok_emb + pos_emb)
         for block in self.transformer.h:
             x = block(x, c)
+        x = self.transformer.ln_f(x) # <<< FIX 2: CALLED THE LAYER HERE
         x = self.lm_head(x, c)
         return torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))
 

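The torch.scatter call in forward zeroes the model's output at each position's current token index, so the score for keeping the current token is pinned rather than model-predicted. A small demonstration of that indexing pattern:

import torch

x = torch.randn(1, 3, 5)          # (batch, seq, vocab) toy scores
idx = torch.tensor([[2, 0, 4]])   # current token at each position

# Writes 0 at x[b, t, idx[b, t]] along the vocab dimension.
out = torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))
print(out[0, 0, 2].item(), out[0, 1, 0].item(), out[0, 2, 4].item())  # 0.0 0.0 0.0
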
@@ -248,12 +251,10 @@ class GPTConfig:
 class GeometricNoise:
     def __init__(self, sigma_min=1e-4, sigma_max=20):
         self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
-    def rate_noise(self, t):
-        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (self.sigmas[1].log() - self.sigmas[0].log())
     def total_noise(self, t):
         return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
     def __call__(self, t):
-        return self.total_noise(t),
+        return self.total_noise(t), None # Rate not needed for sampling
 
 def transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor:
     base_prob = (1 - torch.exp(-delta_sigma[..., None])) / vocab_size

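total_noise interpolates geometrically between the two extremes: sigma_bar(t) = sigma_min**(1-t) * sigma_max**t, giving sigma_min at t=0 and sigma_max at t=1. The deleted rate_noise was its time derivative, which (per the new comment) is not needed for sampling, so __call__ now returns an explicit (total, None) pair. A quick numeric check of the schedule:

import torch

sigmas = torch.tensor([1e-4, 20.0])

def total_noise(t):
    return sigmas[0] ** (1 - t) * sigmas[1] ** t

print(total_noise(torch.tensor(0.0)).item())  # 1e-4 (sigma_min)
print(total_noise(torch.tensor(1.0)).item())  # 20.0 (sigma_max)
print(total_noise(torch.tensor(0.5)).item())  # sqrt(1e-4 * 20) ~= 0.0447
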
@@ -276,6 +277,10 @@ def sample_categorical(probs: torch.Tensor) -> torch.Tensor:
 
 print("Setting up model and device...")
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+print(f"===================================")
+print(f"Using device: {DEVICE}")
+print(f"===================================")
+
 model_args = dict(n_layer=6, n_head=6, n_embd=384, cond_dim=64,
                   bias=False, vocab_size=vocab_size, block_size=CONTEXT_LENGTH, dropout=0.2)
 config = GPTConfig(**model_args)

@@ -295,19 +300,15 @@ NOISE = GeometricNoise(sigma_min=1e-4, sigma_max=20)
 print("Model setup complete. Launching Gradio demo...")
 
 # --- Gradio Generation Function ---
-
 @spaces.GPU
 def generate_text(steps):
-    """Generator function that yields denoised text at each step."""
     steps = int(steps)
     eps = 1e-5
     timesteps = torch.linspace(1, eps, steps + 1, device=DEVICE)
     step_size = (1 - eps) / steps
 
-    # Start with a fresh random sample
     x = torch.randint(0, vocab_size, (1, CONTEXT_LENGTH), device=DEVICE)
 
-    # Initial random text
     initial_text = decode(x)
     yield f"Step 0/{steps} (Initial Noise):\n\n{wrap_text(initial_text)}"
     time.sleep(0.5)

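Two notes on this hunk: on a ZeroGPU Space (the "Running on Zero" hardware above), the @spaces.GPU decorator attaches a GPU only for the duration of each generate_text call, which is why import spaces now sits with the other third-party imports. And the schedule walks t from 1 (pure noise) down to eps in uniform steps; for example:

import torch

steps = 4
eps = 1e-5
timesteps = torch.linspace(1, eps, steps + 1)
step_size = (1 - eps) / steps

print(timesteps)   # ~[1.0, 0.75, 0.5, 0.25, 1e-5]
print(step_size)   # (1 - eps) / 4 ~= 0.25, the gap between consecutive timesteps
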
@@ -317,9 +318,10 @@ def generate_text(steps):
         progress(i / steps, desc=f"Denoising Step {i+1}/{steps}")
 
         t = timesteps[i] * torch.ones(x.shape[0], 1, device=DEVICE)
-        curr_sigma_bar = NOISE(t)
+        curr_sigma_bar, _ = NOISE(t)
 
-
+        next_t = t - step_size
+        next_sigma_bar, _ = NOISE(next_t)
         delta_sigma = curr_sigma_bar - next_sigma_bar
 
         log_score = model(x, curr_sigma_bar)

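This hunk fixes a subtle bug: the old __call__ ended in a trailing comma, so NOISE(t) returned a 1-tuple that the old code bound directly to curr_sigma_bar, and tensor arithmetic on that tuple fails. The new lines unpack the (total, None) pair and compute next_sigma_bar explicitly from next_t = t - step_size. A minimal reproduction of the trailing-comma pitfall:

def noise(t):
    return t * 2,          # trailing comma: this returns a 1-tuple, not a number

curr = noise(0.5)          # (1.0,) -- what the old code bound to curr_sigma_bar
try:
    curr - 0.1
except TypeError as e:
    print(e)               # unsupported operand type(s) for -: 'tuple' and 'float'

curr, _ = noise(0.5) + (None,)[:0] or noise(0.5) if False else (noise(0.5)[0], None)
curr, _ = (0.5 * 2, None)  # the new unpacking pattern: curr_sigma_bar, _ = NOISE(t)
print(curr - 0.1)          # 0.9
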
@@ -329,11 +331,9 @@ def generate_text(steps):
         probs = stag_score * transition(x, delta_sigma)
         x = sample_categorical(probs)
 
-        # Yield the decoded text and step info
         decoded_text = decode(x)
         yield f"Step {i+1}/{steps}:\n\n{wrap_text(decoded_text)}"
 
-    # Final result
     final_text = decode(x)
     yield f"Final Result (Step {steps}/{steps}):\n\n{wrap_text(final_text)}"
 

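For reference, sample_categorical (defined earlier in the file; its body is not part of this diff) draws one token per position from the transition probabilities. A common implementation is the Gumbel-max trick, shown here only as a plausible sketch of such a function, not as the file's actual code:

import torch

def sample_categorical(probs: torch.Tensor) -> torch.Tensor:
    # Gumbel-max: argmax over log p + Gumbel noise is an exact categorical sample,
    # and it works with unnormalized probabilities as well.
    gumbel = -torch.log(-torch.log(torch.rand_like(probs) + 1e-10) + 1e-10)
    return torch.argmax(torch.log(probs + 1e-10) + gumbel, dim=-1)

probs = torch.tensor([[[0.1, 0.7, 0.2]]])   # (batch, seq, vocab)
print(sample_categorical(probs))            # usually tensor([[1]])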