Manuela Larrea
Initial commit: GAN Interactive Demo
29099e4
"""
GAN Interactive Demo - Aplicación Gradio
Visualización interactiva del espacio latente y generación de dígitos MNIST
"""
import gradio as gr
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import plotly.graph_objects as go
import plotly.express as px
from PIL import Image
import io
import os
# Configuración
LATENT_DIM = 100
MODEL_DIR = "models"
# Cargar el generador
print("Cargando modelo generador...")
try:
generator = keras.models.load_model(f'{MODEL_DIR}/generator.h5', compile=False)
print("✓ Generador cargado exitosamente")
except Exception as e:
print(f"Error cargando generador: {e}")
generator = None
# Cargar vectores latentes pre-generados para exploración
try:
latent_vectors = np.load(f'{MODEL_DIR}/latent_vectors.npy')
generated_images_cache = np.load(f'{MODEL_DIR}/generated_images.npy')
print(f"✓ Vectores latentes cargados: {latent_vectors.shape}")
except Exception as e:
print(f"Generando nuevos vectores latentes...")
latent_vectors = np.random.normal(0, 1, (1000, LATENT_DIM))
if generator:
generated_images_cache = generator(latent_vectors, training=False).numpy()
else:
generated_images_cache = None
# Calcular reducción dimensional para visualización
print("Calculando reducción dimensional...")
pca = PCA(n_components=3)
latent_pca = pca.fit_transform(latent_vectors)
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
latent_tsne = tsne.fit_transform(latent_vectors[:500]) # Usar subset para velocidad
print("✓ Aplicación lista")
# ==================== FUNCIONES DE GENERACIÓN ====================
def generate_random_digit():
"""Genera un dígito aleatorio desde un vector latente random"""
if generator is None:
return None, "Modelo no disponible"
# Generar vector latente aleatorio
latent_vector = np.random.normal(0, 1, (1, LATENT_DIM))
# Generar imagen
generated_image = generator(latent_vector, training=False)
image = generated_image[0, :, :, 0].numpy()
# Desnormalizar
image = (image * 127.5 + 127.5).astype(np.uint8)
return image, f"Vector latente: {latent_vector[0, :5]}... (primeros 5 valores)"
def generate_from_sliders(*slider_values):
"""Genera un dígito desde valores de sliders (primeras 10 dimensiones)"""
if generator is None:
return None, "Modelo no disponible"
# Crear vector latente: primeras 10 dimensiones desde sliders, resto aleatorio
latent_vector = np.random.normal(0, 1, (1, LATENT_DIM))
latent_vector[0, :10] = slider_values
# Generar imagen
generated_image = generator(latent_vector, training=False)
image = generated_image[0, :, :, 0].numpy()
# Desnormalizar
image = (image * 127.5 + 127.5).astype(np.uint8)
return image
def interpolate_digits(start_seed, end_seed, steps):
"""Interpola entre dos dígitos generados desde semillas"""
if generator is None:
return None
# Generar vectores latentes desde semillas
np.random.seed(int(start_seed))
latent_start = np.random.normal(0, 1, (1, LATENT_DIM))
np.random.seed(int(end_seed))
latent_end = np.random.normal(0, 1, (1, LATENT_DIM))
# Crear interpolación lineal
alphas = np.linspace(0, 1, int(steps))
# Generar imágenes interpoladas
images = []
for alpha in alphas:
latent_interp = (1 - alpha) * latent_start + alpha * latent_end
generated = generator(latent_interp, training=False)
image = generated[0, :, :, 0].numpy()
image = (image * 127.5 + 127.5).astype(np.uint8)
images.append(image)
# Crear grid de imágenes
n_images = len(images)
cols = min(10, n_images)
rows = (n_images + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))
if rows == 1:
axes = axes.reshape(1, -1)
for idx, image in enumerate(images):
row = idx // cols
col = idx % cols
axes[row, col].imshow(image, cmap='gray')
axes[row, col].axis('off')
axes[row, col].set_title(f'{idx+1}', fontsize=8)
# Ocultar ejes vacíos
for idx in range(n_images, rows * cols):
row = idx // cols
col = idx % cols
axes[row, col].axis('off')
plt.tight_layout()
# Convertir a imagen
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
buf.seek(0)
plt.close()
return Image.open(buf)
def visualize_latent_space_pca():
"""Visualiza el espacio latente en 3D usando PCA"""
fig = go.Figure(data=[go.Scatter3d(
x=latent_pca[:, 0],
y=latent_pca[:, 1],
z=latent_pca[:, 2],
mode='markers',
marker=dict(
size=3,
color=latent_pca[:, 2],
colorscale='Viridis',
showscale=True,
colorbar=dict(title="PC3"),
opacity=0.7
),
text=[f'Punto {i}' for i in range(len(latent_pca))],
hovertemplate='<b>Punto %{text}</b><br>PC1: %{x:.2f}<br>PC2: %{y:.2f}<br>PC3: %{z:.2f}<extra></extra>'
)])
fig.update_layout(
title='Espacio Latente - Visualización PCA 3D',
scene=dict(
xaxis_title='Componente Principal 1',
yaxis_title='Componente Principal 2',
zaxis_title='Componente Principal 3',
bgcolor='rgba(240, 240, 240, 0.9)'
),
width=800,
height=600,
showlegend=False
)
return fig
def visualize_latent_space_tsne():
"""Visualiza el espacio latente en 2D usando t-SNE"""
fig = go.Figure(data=[go.Scatter(
x=latent_tsne[:, 0],
y=latent_tsne[:, 1],
mode='markers',
marker=dict(
size=6,
color=np.arange(len(latent_tsne)),
colorscale='Plasma',
showscale=True,
colorbar=dict(title="Índice"),
opacity=0.7
),
text=[f'Punto {i}' for i in range(len(latent_tsne))],
hovertemplate='<b>Punto %{text}</b><br>t-SNE 1: %{x:.2f}<br>t-SNE 2: %{y:.2f}<extra></extra>'
)])
fig.update_layout(
title='Espacio Latente - Visualización t-SNE 2D',
xaxis_title='Dimensión t-SNE 1',
yaxis_title='Dimensión t-SNE 2',
width=800,
height=600,
plot_bgcolor='rgba(240, 240, 240, 0.9)'
)
return fig
def generate_from_latent_index(index):
"""Genera imagen desde un índice del espacio latente pre-calculado"""
if generated_images_cache is None:
return None, "Cache no disponible"
index = int(index) % len(generated_images_cache)
image = generated_images_cache[index, :, :, 0]
image = (image * 127.5 + 127.5).astype(np.uint8)
return image, f"Índice: {index}\nVector latente: {latent_vectors[index, :5]}..."
def generate_grid_comparison():
"""Genera un grid de comparación de múltiples dígitos"""
if generator is None:
return None
# Generar 16 dígitos aleatorios
latent_vectors_batch = np.random.normal(0, 1, (16, LATENT_DIM))
generated_images = generator(latent_vectors_batch, training=False)
# Crear grid
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for i in range(4):
for j in range(4):
idx = i * 4 + j
image = generated_images[idx, :, :, 0].numpy()
image = (image * 127.5 + 127.5).astype(np.uint8)
axes[i, j].imshow(image, cmap='gray')
axes[i, j].axis('off')
plt.tight_layout()
# Convertir a imagen
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
buf.seek(0)
plt.close()
return Image.open(buf)
# ==================== INTERFAZ GRADIO ====================
# CSS personalizado
custom_css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.tab-nav button {
font-size: 16px;
font-weight: bold;
}
"""
# Crear interfaz
with gr.Blocks(css=custom_css, title="GAN Interactive Demo - MNIST", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🎨 GAN Interactive Demo - Exploración del Espacio Latente
### Generative Adversarial Network entrenada en MNIST
Explora cómo una GAN aprende a generar dígitos manuscritos desde vectores de ruido aleatorio.
Inspirado en el TensorFlow Projector, esta demo te permite navegar el espacio latente de 100 dimensiones.
""")
with gr.Tabs():
# TAB 1: Generación Simple
with gr.Tab("🎲 Generación Aleatoria"):
gr.Markdown("### Genera dígitos aleatorios con un clic")
with gr.Row():
with gr.Column(scale=1):
btn_generate = gr.Button("🎲 Generar Dígito Aleatorio", variant="primary", size="lg")
latent_info = gr.Textbox(label="Información del Vector Latente", lines=2)
with gr.Column(scale=1):
output_image = gr.Image(label="Dígito Generado", type="numpy")
btn_generate.click(
fn=generate_random_digit,
outputs=[output_image, latent_info]
)
# TAB 2: Control Manual
with gr.Tab("🎛️ Control Manual"):
gr.Markdown("### Controla las primeras 10 dimensiones del vector latente")
gr.Markdown("Ajusta los sliders para ver cómo cada dimensión afecta la generación")
with gr.Row():
with gr.Column(scale=1):
sliders = []
for i in range(10):
slider = gr.Slider(
minimum=-3,
maximum=3,
value=0,
step=0.1,
label=f"Dimensión {i+1}"
)
sliders.append(slider)
btn_generate_sliders = gr.Button("Generar desde Sliders", variant="primary")
with gr.Column(scale=1):
output_image_sliders = gr.Image(label="Dígito Generado", type="numpy")
btn_generate_sliders.click(
fn=generate_from_sliders,
inputs=sliders,
outputs=output_image_sliders
)
# TAB 3: Interpolación
with gr.Tab("🔄 Interpolación"):
gr.Markdown("### Morphing entre dos dígitos")
gr.Markdown("Observa cómo la GAN transforma suavemente un dígito en otro")
with gr.Row():
with gr.Column(scale=1):
start_seed = gr.Number(label="Semilla Inicial", value=42)
end_seed = gr.Number(label="Semilla Final", value=123)
steps = gr.Slider(
minimum=5,
maximum=20,
value=10,
step=1,
label="Número de Pasos"
)
btn_interpolate = gr.Button("🔄 Generar Interpolación", variant="primary")
with gr.Column(scale=2):
output_interpolation = gr.Image(label="Secuencia de Interpolación")
btn_interpolate.click(
fn=interpolate_digits,
inputs=[start_seed, end_seed, steps],
outputs=output_interpolation
)
# TAB 4: Exploración del Espacio Latente
with gr.Tab("🌌 Espacio Latente"):
gr.Markdown("### Visualización del Espacio Latente de 100 Dimensiones")
gr.Markdown("Similar al TensorFlow Projector: explora cómo se distribuyen los vectores latentes")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("#### Visualización 3D (PCA)")
btn_pca = gr.Button("Mostrar PCA 3D", variant="secondary")
plot_pca = gr.Plot(label="Espacio Latente - PCA")
btn_pca.click(
fn=visualize_latent_space_pca,
outputs=plot_pca
)
with gr.Column(scale=1):
gr.Markdown("#### Visualización 2D (t-SNE)")
btn_tsne = gr.Button("Mostrar t-SNE 2D", variant="secondary")
plot_tsne = gr.Plot(label="Espacio Latente - t-SNE")
btn_tsne.click(
fn=visualize_latent_space_tsne,
outputs=plot_tsne
)
gr.Markdown("---")
gr.Markdown("#### Genera desde un punto específico del espacio")
with gr.Row():
with gr.Column(scale=1):
latent_index = gr.Slider(
minimum=0,
maximum=999,
value=0,
step=1,
label="Índice del Vector Latente"
)
btn_generate_index = gr.Button("Generar desde Índice", variant="primary")
latent_index_info = gr.Textbox(label="Información", lines=2)
with gr.Column(scale=1):
output_image_index = gr.Image(label="Dígito Generado", type="numpy")
btn_generate_index.click(
fn=generate_from_latent_index,
inputs=latent_index,
outputs=[output_image_index, latent_index_info]
)
# TAB 5: Grid de Comparación
with gr.Tab("📊 Grid de Dígitos"):
gr.Markdown("### Genera múltiples dígitos simultáneamente")
gr.Markdown("Observa la diversidad y calidad de las generaciones")
with gr.Row():
with gr.Column(scale=1):
btn_grid = gr.Button("🎨 Generar Grid 4×4", variant="primary", size="lg")
with gr.Column(scale=2):
output_grid = gr.Image(label="Grid de 16 Dígitos Generados")
btn_grid.click(
fn=generate_grid_comparison,
outputs=output_grid
)
gr.Markdown("""
---
### 📚 Sobre esta Demo
Esta aplicación interactiva demuestra el poder de las **Redes Generativas Adversarias (GANs)** entrenadas en el dataset MNIST.
**Características:**
- **Espacio Latente de 100 dimensiones**: Cada dígito es generado desde un vector de 100 números aleatorios
- **Visualización dimensional**: PCA y t-SNE reducen las 100 dimensiones a 2D/3D para visualización
- **Interpolación suave**: Demuestra que el espacio latente es continuo y significativo
- **Generación instantánea**: Sin necesidad de re-entrenar
**Arquitectura:**
- **Generador**: 7×7×256 → 14×14×64 → 28×28×1 (Conv2DTranspose + BatchNorm + LeakyReLU)
- **Discriminador**: 28×28×1 → 14×14×64 → 7×7×128 → Logit (Conv2D + Dropout)
- **Entrenamiento**: 50 épocas, Adam optimizer, Binary Cross-Entropy loss
🎓 **Creado para la clase de Machine Learning**
""")
# Lanzar aplicación
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)