File size: 15,878 Bytes
29099e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
"""
GAN Interactive Demo - Aplicación Gradio
Visualización interactiva del espacio latente y generación de dígitos MNIST
"""
import gradio as gr
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import plotly.graph_objects as go
import plotly.express as px
from PIL import Image
import io
import os

# Configuración
LATENT_DIM = 100
MODEL_DIR = "models"

# Cargar el generador
print("Cargando modelo generador...")
try:
    generator = keras.models.load_model(f'{MODEL_DIR}/generator.h5', compile=False)
    print("✓ Generador cargado exitosamente")
except Exception as e:
    print(f"Error cargando generador: {e}")
    generator = None

# Cargar vectores latentes pre-generados para exploración
try:
    latent_vectors = np.load(f'{MODEL_DIR}/latent_vectors.npy')
    generated_images_cache = np.load(f'{MODEL_DIR}/generated_images.npy')
    print(f"✓ Vectores latentes cargados: {latent_vectors.shape}")
except Exception as e:
    print(f"Generando nuevos vectores latentes...")
    latent_vectors = np.random.normal(0, 1, (1000, LATENT_DIM))
    if generator:
        generated_images_cache = generator(latent_vectors, training=False).numpy()
    else:
        generated_images_cache = None

# Calcular reducción dimensional para visualización
print("Calculando reducción dimensional...")
pca = PCA(n_components=3)
latent_pca = pca.fit_transform(latent_vectors)

tsne = TSNE(n_components=2, random_state=42, perplexity=30)
latent_tsne = tsne.fit_transform(latent_vectors[:500])  # Usar subset para velocidad

print("✓ Aplicación lista")

# ==================== FUNCIONES DE GENERACIÓN ====================

def generate_random_digit():
    """Genera un dígito aleatorio desde un vector latente random"""
    if generator is None:
        return None, "Modelo no disponible"
    
    # Generar vector latente aleatorio
    latent_vector = np.random.normal(0, 1, (1, LATENT_DIM))
    
    # Generar imagen
    generated_image = generator(latent_vector, training=False)
    image = generated_image[0, :, :, 0].numpy()
    
    # Desnormalizar
    image = (image * 127.5 + 127.5).astype(np.uint8)
    
    return image, f"Vector latente: {latent_vector[0, :5]}... (primeros 5 valores)"

def generate_from_sliders(*slider_values):
    """Genera un dígito desde valores de sliders (primeras 10 dimensiones)"""
    if generator is None:
        return None, "Modelo no disponible"
    
    # Crear vector latente: primeras 10 dimensiones desde sliders, resto aleatorio
    latent_vector = np.random.normal(0, 1, (1, LATENT_DIM))
    latent_vector[0, :10] = slider_values
    
    # Generar imagen
    generated_image = generator(latent_vector, training=False)
    image = generated_image[0, :, :, 0].numpy()
    
    # Desnormalizar
    image = (image * 127.5 + 127.5).astype(np.uint8)
    
    return image

def interpolate_digits(start_seed, end_seed, steps):
    """Interpola entre dos dígitos generados desde semillas"""
    if generator is None:
        return None
    
    # Generar vectores latentes desde semillas
    np.random.seed(int(start_seed))
    latent_start = np.random.normal(0, 1, (1, LATENT_DIM))
    
    np.random.seed(int(end_seed))
    latent_end = np.random.normal(0, 1, (1, LATENT_DIM))
    
    # Crear interpolación lineal
    alphas = np.linspace(0, 1, int(steps))
    
    # Generar imágenes interpoladas
    images = []
    for alpha in alphas:
        latent_interp = (1 - alpha) * latent_start + alpha * latent_end
        generated = generator(latent_interp, training=False)
        image = generated[0, :, :, 0].numpy()
        image = (image * 127.5 + 127.5).astype(np.uint8)
        images.append(image)
    
    # Crear grid de imágenes
    n_images = len(images)
    cols = min(10, n_images)
    rows = (n_images + cols - 1) // cols
    
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))
    if rows == 1:
        axes = axes.reshape(1, -1)
    
    for idx, image in enumerate(images):
        row = idx // cols
        col = idx % cols
        axes[row, col].imshow(image, cmap='gray')
        axes[row, col].axis('off')
        axes[row, col].set_title(f'{idx+1}', fontsize=8)
    
    # Ocultar ejes vacíos
    for idx in range(n_images, rows * cols):
        row = idx // cols
        col = idx % cols
        axes[row, col].axis('off')
    
    plt.tight_layout()
    
    # Convertir a imagen
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    plt.close()
    
    return Image.open(buf)

def visualize_latent_space_pca():
    """Visualiza el espacio latente en 3D usando PCA"""
    fig = go.Figure(data=[go.Scatter3d(
        x=latent_pca[:, 0],
        y=latent_pca[:, 1],
        z=latent_pca[:, 2],
        mode='markers',
        marker=dict(
            size=3,
            color=latent_pca[:, 2],
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title="PC3"),
            opacity=0.7
        ),
        text=[f'Punto {i}' for i in range(len(latent_pca))],
        hovertemplate='<b>Punto %{text}</b><br>PC1: %{x:.2f}<br>PC2: %{y:.2f}<br>PC3: %{z:.2f}<extra></extra>'
    )])
    
    fig.update_layout(
        title='Espacio Latente - Visualización PCA 3D',
        scene=dict(
            xaxis_title='Componente Principal 1',
            yaxis_title='Componente Principal 2',
            zaxis_title='Componente Principal 3',
            bgcolor='rgba(240, 240, 240, 0.9)'
        ),
        width=800,
        height=600,
        showlegend=False
    )
    
    return fig

def visualize_latent_space_tsne():
    """Visualiza el espacio latente en 2D usando t-SNE"""
    fig = go.Figure(data=[go.Scatter(
        x=latent_tsne[:, 0],
        y=latent_tsne[:, 1],
        mode='markers',
        marker=dict(
            size=6,
            color=np.arange(len(latent_tsne)),
            colorscale='Plasma',
            showscale=True,
            colorbar=dict(title="Índice"),
            opacity=0.7
        ),
        text=[f'Punto {i}' for i in range(len(latent_tsne))],
        hovertemplate='<b>Punto %{text}</b><br>t-SNE 1: %{x:.2f}<br>t-SNE 2: %{y:.2f}<extra></extra>'
    )])
    
    fig.update_layout(
        title='Espacio Latente - Visualización t-SNE 2D',
        xaxis_title='Dimensión t-SNE 1',
        yaxis_title='Dimensión t-SNE 2',
        width=800,
        height=600,
        plot_bgcolor='rgba(240, 240, 240, 0.9)'
    )
    
    return fig

def generate_from_latent_index(index):
    """Genera imagen desde un índice del espacio latente pre-calculado"""
    if generated_images_cache is None:
        return None, "Cache no disponible"
    
    index = int(index) % len(generated_images_cache)
    image = generated_images_cache[index, :, :, 0]
    image = (image * 127.5 + 127.5).astype(np.uint8)
    
    return image, f"Índice: {index}\nVector latente: {latent_vectors[index, :5]}..."

def generate_grid_comparison():
    """Genera un grid de comparación de múltiples dígitos"""
    if generator is None:
        return None
    
    # Generar 16 dígitos aleatorios
    latent_vectors_batch = np.random.normal(0, 1, (16, LATENT_DIM))
    generated_images = generator(latent_vectors_batch, training=False)
    
    # Crear grid
    fig, axes = plt.subplots(4, 4, figsize=(10, 10))
    
    for i in range(4):
        for j in range(4):
            idx = i * 4 + j
            image = generated_images[idx, :, :, 0].numpy()
            image = (image * 127.5 + 127.5).astype(np.uint8)
            
            axes[i, j].imshow(image, cmap='gray')
            axes[i, j].axis('off')
    
    plt.tight_layout()
    
    # Convertir a imagen
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    plt.close()
    
    return Image.open(buf)

# ==================== INTERFAZ GRADIO ====================

# CSS personalizado
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.tab-nav button {
    font-size: 16px;
    font-weight: bold;
}
"""

# Crear interfaz
with gr.Blocks(css=custom_css, title="GAN Interactive Demo - MNIST", theme=gr.themes.Soft()) as demo:
    
    gr.Markdown("""
    # 🎨 GAN Interactive Demo - Exploración del Espacio Latente
    
    ### Generative Adversarial Network entrenada en MNIST
    
    Explora cómo una GAN aprende a generar dígitos manuscritos desde vectores de ruido aleatorio.
    Inspirado en el TensorFlow Projector, esta demo te permite navegar el espacio latente de 100 dimensiones.
    """)
    
    with gr.Tabs():
        
        # TAB 1: Generación Simple
        with gr.Tab("🎲 Generación Aleatoria"):
            gr.Markdown("### Genera dígitos aleatorios con un clic")
            
            with gr.Row():
                with gr.Column(scale=1):
                    btn_generate = gr.Button("🎲 Generar Dígito Aleatorio", variant="primary", size="lg")
                    latent_info = gr.Textbox(label="Información del Vector Latente", lines=2)
                
                with gr.Column(scale=1):
                    output_image = gr.Image(label="Dígito Generado", type="numpy")
            
            btn_generate.click(
                fn=generate_random_digit,
                outputs=[output_image, latent_info]
            )
        
        # TAB 2: Control Manual
        with gr.Tab("🎛️ Control Manual"):
            gr.Markdown("### Controla las primeras 10 dimensiones del vector latente")
            gr.Markdown("Ajusta los sliders para ver cómo cada dimensión afecta la generación")
            
            with gr.Row():
                with gr.Column(scale=1):
                    sliders = []
                    for i in range(10):
                        slider = gr.Slider(
                            minimum=-3,
                            maximum=3,
                            value=0,
                            step=0.1,
                            label=f"Dimensión {i+1}"
                        )
                        sliders.append(slider)
                    
                    btn_generate_sliders = gr.Button("Generar desde Sliders", variant="primary")
                
                with gr.Column(scale=1):
                    output_image_sliders = gr.Image(label="Dígito Generado", type="numpy")
            
            btn_generate_sliders.click(
                fn=generate_from_sliders,
                inputs=sliders,
                outputs=output_image_sliders
            )
        
        # TAB 3: Interpolación
        with gr.Tab("🔄 Interpolación"):
            gr.Markdown("### Morphing entre dos dígitos")
            gr.Markdown("Observa cómo la GAN transforma suavemente un dígito en otro")
            
            with gr.Row():
                with gr.Column(scale=1):
                    start_seed = gr.Number(label="Semilla Inicial", value=42)
                    end_seed = gr.Number(label="Semilla Final", value=123)
                    steps = gr.Slider(
                        minimum=5,
                        maximum=20,
                        value=10,
                        step=1,
                        label="Número de Pasos"
                    )
                    btn_interpolate = gr.Button("🔄 Generar Interpolación", variant="primary")
                
                with gr.Column(scale=2):
                    output_interpolation = gr.Image(label="Secuencia de Interpolación")
            
            btn_interpolate.click(
                fn=interpolate_digits,
                inputs=[start_seed, end_seed, steps],
                outputs=output_interpolation
            )
        
        # TAB 4: Exploración del Espacio Latente
        with gr.Tab("🌌 Espacio Latente"):
            gr.Markdown("### Visualización del Espacio Latente de 100 Dimensiones")
            gr.Markdown("Similar al TensorFlow Projector: explora cómo se distribuyen los vectores latentes")
            
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("#### Visualización 3D (PCA)")
                    btn_pca = gr.Button("Mostrar PCA 3D", variant="secondary")
                    plot_pca = gr.Plot(label="Espacio Latente - PCA")
                    
                    btn_pca.click(
                        fn=visualize_latent_space_pca,
                        outputs=plot_pca
                    )
                
                with gr.Column(scale=1):
                    gr.Markdown("#### Visualización 2D (t-SNE)")
                    btn_tsne = gr.Button("Mostrar t-SNE 2D", variant="secondary")
                    plot_tsne = gr.Plot(label="Espacio Latente - t-SNE")
                    
                    btn_tsne.click(
                        fn=visualize_latent_space_tsne,
                        outputs=plot_tsne
                    )
            
            gr.Markdown("---")
            gr.Markdown("#### Genera desde un punto específico del espacio")
            
            with gr.Row():
                with gr.Column(scale=1):
                    latent_index = gr.Slider(
                        minimum=0,
                        maximum=999,
                        value=0,
                        step=1,
                        label="Índice del Vector Latente"
                    )
                    btn_generate_index = gr.Button("Generar desde Índice", variant="primary")
                    latent_index_info = gr.Textbox(label="Información", lines=2)
                
                with gr.Column(scale=1):
                    output_image_index = gr.Image(label="Dígito Generado", type="numpy")
            
            btn_generate_index.click(
                fn=generate_from_latent_index,
                inputs=latent_index,
                outputs=[output_image_index, latent_index_info]
            )
        
        # TAB 5: Grid de Comparación
        with gr.Tab("📊 Grid de Dígitos"):
            gr.Markdown("### Genera múltiples dígitos simultáneamente")
            gr.Markdown("Observa la diversidad y calidad de las generaciones")
            
            with gr.Row():
                with gr.Column(scale=1):
                    btn_grid = gr.Button("🎨 Generar Grid 4×4", variant="primary", size="lg")
                
                with gr.Column(scale=2):
                    output_grid = gr.Image(label="Grid de 16 Dígitos Generados")
            
            btn_grid.click(
                fn=generate_grid_comparison,
                outputs=output_grid
            )
    
    gr.Markdown("""
    ---
    ### 📚 Sobre esta Demo
    
    Esta aplicación interactiva demuestra el poder de las **Redes Generativas Adversarias (GANs)** entrenadas en el dataset MNIST.
    
    **Características:**
    - **Espacio Latente de 100 dimensiones**: Cada dígito es generado desde un vector de 100 números aleatorios
    - **Visualización dimensional**: PCA y t-SNE reducen las 100 dimensiones a 2D/3D para visualización
    - **Interpolación suave**: Demuestra que el espacio latente es continuo y significativo
    - **Generación instantánea**: Sin necesidad de re-entrenar
    
    **Arquitectura:**
    - **Generador**: 7×7×256 → 14×14×64 → 28×28×1 (Conv2DTranspose + BatchNorm + LeakyReLU)
    - **Discriminador**: 28×28×1 → 14×14×64 → 7×7×128 → Logit (Conv2D + Dropout)
    - **Entrenamiento**: 50 épocas, Adam optimizer, Binary Cross-Entropy loss
    
    🎓 **Creado para la clase de Machine Learning**
    """)

# Lanzar aplicación
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )