| import os | |
| import io | |
| import gc | |
| import math | |
| import time | |
| import uuid | |
| import json | |
| import zlib | |
| import spaces | |
| import random | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass, field, asdict | |
| from typing import Dict, List, Tuple, Optional, Any, Union | |
| from enum import Enum | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer | |
| import mido | |
| from mido import Message, MidiFile, MidiTrack | |
| # Configuration Classes | |
| class ComputeMode(Enum): | |
| """Enum for computation modes.""" | |
| FULL_MODEL = "Full model" | |
| MOCK_LATENTS = "Mock latents" | |
| class MusicRole(Enum): | |
| """Enum for musical roles/layers.""" | |
| MELODY = "melody" | |
| BASS = "bass" | |
| HARMONY = "harmony" | |
| PAD = "pad" | |
| ACCENT = "accent" | |
| ATMOSPHERE = "atmosphere" | |
| @dataclass | |
| class ScaleDefinition: | |
| """Represents a musical scale.""" | |
| name: str | |
| notes: List[int] | |
| description: str = "" | |
| def __post_init__(self): | |
| """Validate scale notes are within MIDI range.""" | |
| for note in self.notes: | |
| if not 0 <= note <= 127: | |
| raise ValueError(f"MIDI note {note} out of range (0-127)") | |
| @dataclass | |
| class InstrumentMapping: | |
| """Maps a layer to an instrument and musical role.""" | |
| program: int # MIDI program number | |
| role: MusicRole | |
| channel: int | |
| name: str = "" | |
| def __post_init__(self): | |
| """Validate MIDI program and channel.""" | |
| if not 0 <= self.program <= 127: | |
| raise ValueError(f"MIDI program {self.program} out of range") | |
| if not 0 <= self.channel <= 15: | |
| raise ValueError(f"MIDI channel {self.channel} out of range") | |
| @dataclass | |
| class GenerationConfig: | |
| """Complete configuration for music generation.""" | |
| model_name: str | |
| compute_mode: ComputeMode | |
| base_tempo: int | |
| velocity_range: Tuple[int, int] | |
| scale: ScaleDefinition | |
| num_layers_limit: int | |
| seed: int | |
| instrument_preset: str | |
| # Additional configuration options | |
| quantization_grid: int = 120 | |
| octave_range: int = 2 | |
| dynamics_curve: str = "linear" # linear, exponential, logarithmic | |
| def validate(self): | |
| """Validate configuration parameters.""" | |
| if not 1 <= self.base_tempo <= 2000: | |
| raise ValueError("Tempo must be between 1 and 2000") | |
| if not 1 <= self.velocity_range[0] < self.velocity_range[1] <= 127: | |
| raise ValueError("Invalid velocity range") | |
| if not 1 <= self.num_layers_limit <= 32: | |
| raise ValueError("Number of layers must be between 1 and 32") | |
| def to_dict(self) -> Dict: | |
| """Convert config to dictionary for serialization.""" | |
| return { | |
| "model_name": self.model_name, | |
| "compute_mode": self.compute_mode.value, | |
| "base_tempo": self.base_tempo, | |
| "velocity_range": self.velocity_range, | |
| "scale_name": self.scale.name, | |
| "scale_notes": self.scale.notes, | |
| "num_layers_limit": self.num_layers_limit, | |
| "seed": self.seed, | |
| "instrument_preset": self.instrument_preset, | |
| "quantization_grid": self.quantization_grid, | |
| "octave_range": self.octave_range, | |
| "dynamics_curve": self.dynamics_curve | |
| } | |
| @classmethod | |
| def from_dict(cls, data: Dict, scale_manager: "ScaleManager") -> "GenerationConfig": | |
| """Create config from dictionary.""" | |
| scale = scale_manager.get_scale(data["scale_name"]) | |
| if scale is None: | |
| scale = ScaleDefinition(name="Custom", notes=data["scale_notes"]) | |
| return cls( | |
| model_name=data["model_name"], | |
| compute_mode=ComputeMode(data["compute_mode"]), | |
| base_tempo=data["base_tempo"], | |
| velocity_range=tuple(data["velocity_range"]), | |
| scale=scale, | |
| num_layers_limit=data["num_layers_limit"], | |
| seed=data["seed"], | |
| instrument_preset=data["instrument_preset"], | |
| quantization_grid=data.get("quantization_grid", 120), | |
| octave_range=data.get("octave_range", 2), | |
| dynamics_curve=data.get("dynamics_curve", "linear") | |
| ) | |
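A quick way to sanity-check the config plumbing above is a construct/validate/serialize round trip. A minimal sketch, assuming the classes defined in this file are in scope; the model name and values are only illustrative:

```python
scale = ScaleDefinition(
    name="C major",
    notes=[60, 62, 64, 65, 67, 69, 71, 72],
    description="Major scale",
)
config = GenerationConfig(
    model_name="unsloth/Qwen3-14B-Base",
    compute_mode=ComputeMode.MOCK_LATENTS,
    base_tempo=480,
    velocity_range=(40, 90),
    scale=scale,
    num_layers_limit=6,
    seed=42,
    instrument_preset="Ensemble (melody+bass+pad etc.)",
)
config.validate()                        # raises ValueError on out-of-range settings
print(config.to_dict()["scale_notes"])   # [60, 62, 64, 65, 67, 69, 71, 72]
```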
| @dataclass | |
| class Latents: | |
| """Container for model latents.""" | |
| hidden_states: List[torch.Tensor] | |
| attentions: List[torch.Tensor] | |
| num_layers: int | |
| num_tokens: int | |
| metadata: Dict[str, Any] = field(default_factory=dict) | |
| # Music Components | |
| class ScaleManager: | |
| """Manages musical scales and modes.""" | |
| def __init__(self): | |
| """Initialize with default scales.""" | |
| self.scales = { | |
| "C pentatonic": ScaleDefinition( | |
| "C pentatonic", | |
| [60, 62, 65, 67, 70, 72, 74, 77], | |
| "Major pentatonic scale" | |
| ), | |
| "C major": ScaleDefinition( | |
| "C major", | |
| [60, 62, 64, 65, 67, 69, 71, 72], | |
| "Major scale (Ionian mode)" | |
| ), | |
| "A minor": ScaleDefinition( | |
| "A minor", | |
| [57, 59, 60, 62, 64, 65, 67, 69], | |
| "Natural minor scale (Aeolian mode)" | |
| ), | |
| "D dorian": ScaleDefinition( | |
| "D dorian", | |
| [62, 64, 65, 67, 69, 71, 72, 74], | |
| "Dorian mode - minor with raised 6th" | |
| ), | |
| "E phrygian": ScaleDefinition( | |
| "E phrygian", | |
| [64, 65, 67, 69, 71, 72, 74, 76], | |
| "Phrygian mode - minor with lowered 2nd" | |
| ), | |
| "G mixolydian": ScaleDefinition( | |
| "G mixolydian", | |
| [67, 69, 71, 72, 74, 76, 77, 79], | |
| "Mixolydian mode - major with lowered 7th" | |
| ), | |
| "Blues scale": ScaleDefinition( | |
| "Blues scale", | |
| [60, 63, 65, 66, 67, 70, 72, 75], | |
| "Blues scale with blue notes" | |
| ), | |
| "Chromatic": ScaleDefinition( | |
| "Chromatic", | |
| list(range(60, 72)), | |
| "All 12 semitones" | |
| ) | |
| } | |
| def get_scale(self, name: str) -> Optional[ScaleDefinition]: | |
| """Get scale by name.""" | |
| return self.scales.get(name) | |
| def add_custom_scale(self, name: str, notes: List[int], description: str = "") -> ScaleDefinition: | |
| """Add a custom scale.""" | |
| scale = ScaleDefinition(name, notes, description) | |
| self.scales[name] = scale | |
| return scale | |
| def list_scales(self) -> List[str]: | |
| """Get list of available scale names.""" | |
| return list(self.scales.keys()) | |
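The manager also accepts user-defined scales at runtime; a small sketch (the scale below is a made-up example, not one of the built-ins):

```python
manager = ScaleManager()
print(manager.list_scales())   # built-in names, e.g. "C pentatonic", "C major", ...

# Register a custom scale; __post_init__ rejects notes outside MIDI range 0-127.
custom = manager.add_custom_scale(
    "A hirajoshi (example)",
    [57, 59, 60, 64, 65, 69, 71, 72],
    "Illustrative custom pentatonic",
)
assert manager.get_scale("A hirajoshi (example)") is custom
```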
| class InstrumentPresetManager: | |
| """Manages instrument presets for different musical styles.""" | |
| def __init__(self): | |
| """Initialize with default presets.""" | |
| self.presets = { | |
| "Ensemble (melody+bass+pad etc.)": [ | |
| InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"), | |
| InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"), | |
| InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"), | |
| InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"), | |
| InstrumentMapping(11, MusicRole.ACCENT, 4, "Vibraphone"), | |
| InstrumentMapping(89, MusicRole.ATMOSPHERE, 5, "Pad Warm") | |
| ], | |
| "Piano Trio (melody+bass+harmony)": [ | |
| InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"), | |
| InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"), | |
| InstrumentMapping(0, MusicRole.HARMONY, 2, "Piano"), | |
| InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"), | |
| InstrumentMapping(0, MusicRole.ACCENT, 4, "Piano"), | |
| InstrumentMapping(0, MusicRole.ATMOSPHERE, 5, "Piano") | |
| ], | |
| "Pads & Atmosphere": [ | |
| InstrumentMapping(48, MusicRole.PAD, 0, "String Ensemble"), | |
| InstrumentMapping(48, MusicRole.PAD, 1, "String Ensemble"), | |
| InstrumentMapping(89, MusicRole.ATMOSPHERE, 2, "Pad Warm"), | |
| InstrumentMapping(89, MusicRole.ATMOSPHERE, 3, "Pad Warm"), | |
| InstrumentMapping(46, MusicRole.HARMONY, 4, "Harp"), | |
| InstrumentMapping(11, MusicRole.ACCENT, 5, "Vibraphone") | |
| ], | |
| "Orchestral": [ | |
| InstrumentMapping(40, MusicRole.MELODY, 0, "Violin"), | |
| InstrumentMapping(42, MusicRole.BASS, 1, "Cello"), | |
| InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"), | |
| InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"), | |
| InstrumentMapping(73, MusicRole.ACCENT, 4, "Flute"), | |
| InstrumentMapping(49, MusicRole.ATMOSPHERE, 5, "Slow Strings") | |
| ], | |
| "Electronic": [ | |
| InstrumentMapping(80, MusicRole.MELODY, 0, "Lead Square"), | |
| InstrumentMapping(38, MusicRole.BASS, 1, "Synth Bass"), | |
| InstrumentMapping(81, MusicRole.HARMONY, 2, "Lead Sawtooth"), | |
| InstrumentMapping(90, MusicRole.PAD, 3, "Pad Polysynth"), | |
| InstrumentMapping(82, MusicRole.ACCENT, 4, "Lead Calliope"), | |
| InstrumentMapping(91, MusicRole.ATMOSPHERE, 5, "Pad Bowed") | |
| ] | |
| } | |
| def get_preset(self, name: str) -> List[InstrumentMapping]: | |
| """Get instrument preset by name.""" | |
| return self.presets.get(name, self.presets["Ensemble (melody+bass+pad etc.)"]) | |
| def list_presets(self) -> List[str]: | |
| """Get list of available preset names.""" | |
| return list(self.presets.keys()) | |
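To see how a preset maps layers onto channels and General MIDI programs, one can iterate over its mappings; a minimal sketch:

```python
presets = InstrumentPresetManager()
for mapping in presets.get_preset("Orchestral"):
    # channel, GM program number, musical role, display name
    print(mapping.channel, mapping.program, mapping.role.value, mapping.name)
```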
| # Music Generation Components | |
| class MusicMathUtils: | |
| """Utility class for music-related mathematical operations.""" | |
| @staticmethod | |
| def entropy(p: np.ndarray) -> float: | |
| """Calculate Shannon entropy of a probability distribution.""" | |
| p = p / (p.sum() + 1e-9) | |
| return float(-np.sum(p * np.log2(p + 1e-9))) | |
| @staticmethod | |
| def quantize_time(time_val: int, grid: int = 120) -> int: | |
| """Quantize time value to grid.""" | |
| return int(round(time_val / grid) * grid) | |
| @staticmethod | |
| def norm_to_scale(val: float, scale: np.ndarray, octave_range: int = 2) -> int: | |
| """Map normalized value to scale note with octave range.""" | |
| octave = int(abs(val) * octave_range) * 12 | |
| note_idx = int(abs(val * 100) % len(scale)) | |
| return int(scale[note_idx] + octave) | |
| @staticmethod | |
| def apply_dynamics_curve(value: float, curve_type: str = "linear") -> float: | |
| """Apply dynamics curve to a value.""" | |
| value = np.clip(value, 0, 1) | |
| if curve_type == "exponential": | |
| return value ** 2 | |
| elif curve_type == "logarithmic": | |
| return np.log1p(value * np.e) / np.log1p(np.e) | |
| else: # linear | |
| return value | |
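A few toy calls illustrate what these helpers return (assuming the module-level numpy import above is in scope):

```python
utils = MusicMathUtils()

# Shannon entropy of a uniform 4-bin distribution is ~2 bits.
print(utils.entropy(np.array([0.25, 0.25, 0.25, 0.25])))   # ≈ 2.0

# Times snap to the nearest multiple of the grid (120 ticks by default).
print(utils.quantize_time(131))                             # 120
print(utils.quantize_time(200, 120))                        # 240

# A normalized activation picks a scale degree plus an octave offset.
print(utils.norm_to_scale(0.73, np.array([60, 62, 64, 67, 69]), octave_range=2))

# Dynamics curves reshape a 0-1 value.
print(utils.apply_dynamics_curve(0.5, "exponential"))       # 0.25
```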
| class NoteGenerator: | |
| """Generates notes based on neural network latents.""" | |
| # Role-specific frequency multipliers | |
| ROLE_FREQUENCIES = { | |
| MusicRole.MELODY: 2.0, | |
| MusicRole.BASS: 0.5, | |
| MusicRole.HARMONY: 1.5, | |
| MusicRole.PAD: 0.25, | |
| MusicRole.ACCENT: 3.0, | |
| MusicRole.ATMOSPHERE: 0.33 | |
| } | |
| # Role-specific weight distributions | |
| ROLE_WEIGHTS = { | |
| MusicRole.MELODY: np.array([0.4, 0.2, 0.2, 0.1, 0.1]), | |
| MusicRole.BASS: np.array([0.1, 0.4, 0.1, 0.3, 0.1]), | |
| MusicRole.HARMONY: np.array([0.2, 0.2, 0.3, 0.2, 0.1]), | |
| MusicRole.PAD: np.array([0.1, 0.3, 0.1, 0.1, 0.4]), | |
| MusicRole.ACCENT: np.array([0.5, 0.1, 0.2, 0.1, 0.1]), | |
| MusicRole.ATMOSPHERE: np.array([0.1, 0.2, 0.1, 0.2, 0.4]) | |
| } | |
| def __init__(self, config: GenerationConfig): | |
| """Initialize with generation configuration.""" | |
| self.config = config | |
| self.math_utils = MusicMathUtils() | |
| self.history: Dict[int, int] = {} | |
| def create_note_probability( | |
| self, | |
| layer_idx: int, | |
| token_idx: int, | |
| attention_val: float, | |
| hidden_state: np.ndarray, | |
| num_tokens: int, | |
| role: MusicRole | |
| ) -> float: | |
| """Calculate probability of playing a note based on multiple factors.""" | |
| # Base probability from attention | |
| base_prob = 1 / (1 + np.exp(-10 * (attention_val - 0.5))) | |
| # Temporal factor based on role frequency | |
| temporal_factor = 0.5 + 0.5 * np.sin( | |
| 2 * np.pi * self.ROLE_FREQUENCIES[role] * token_idx / max(1, num_tokens) | |
| ) | |
| # Energy factor from hidden state norm | |
| energy = np.linalg.norm(hidden_state) | |
| energy_factor = np.tanh(energy / 10) | |
| # Variance factor | |
| local_variance = np.var(hidden_state) | |
| variance_factor = 1 - np.exp(-local_variance) | |
| # Entropy factor | |
| state_entropy = self.math_utils.entropy(np.abs(hidden_state)) | |
| max_entropy = np.log2(max(2, hidden_state.shape[0])) | |
| entropy_factor = state_entropy / max_entropy | |
| # Combine factors with role-specific weights | |
| factors = np.array([base_prob, temporal_factor, energy_factor, variance_factor, entropy_factor]) | |
| weights = self.ROLE_WEIGHTS[role] | |
| combined_prob = float(np.dot(weights, factors)) | |
| # Add deterministic noise for variation | |
| noise_seed = layer_idx * 1000 + token_idx | |
| noise = 0.1 * (np.sin(noise_seed * 0.1) + np.cos(noise_seed * 0.23)) / 2 | |
| # Apply dynamics curve (clamp to non-negative first so the fractional power stays real) | |
| final_prob = max(0.0, combined_prob + noise) ** 1.5 | |
| final_prob = self.math_utils.apply_dynamics_curve(final_prob, self.config.dynamics_curve) | |
| return float(np.clip(final_prob, 0, 1)) | |
| def should_play_note( | |
| self, | |
| layer_idx: int, | |
| token_idx: int, | |
| attention_val: float, | |
| hidden_state: np.ndarray, | |
| num_tokens: int, | |
| role: MusicRole | |
| ) -> bool: | |
| """Determine if a note should be played.""" | |
| prob = self.create_note_probability( | |
| layer_idx, token_idx, attention_val, hidden_state, num_tokens, role | |
| ) | |
| # Adjust probability based on silence duration | |
| if layer_idx in self.history: | |
| last_played = self.history[layer_idx] | |
| silence_duration = token_idx - last_played | |
| prob *= (1 + np.tanh(silence_duration / 5) * 0.5) | |
| # Stochastic decision | |
| play_note = np.random.random() < prob | |
| if play_note: | |
| self.history[layer_idx] = token_idx | |
| return play_note | |
| def generate_notes_for_role( | |
| self, | |
| role: MusicRole, | |
| hidden_state: np.ndarray, | |
| scale: np.ndarray | |
| ) -> List[int]: | |
| """Generate notes based on role and hidden state.""" | |
| if role == MusicRole.MELODY: | |
| note = self.math_utils.norm_to_scale( | |
| hidden_state[0], scale, octave_range=1 | |
| ) | |
| return [note] | |
| elif role == MusicRole.BASS: | |
| note = self.math_utils.norm_to_scale( | |
| hidden_state[0], scale, octave_range=0 | |
| ) - 12 | |
| return [note] | |
| elif role == MusicRole.HARMONY: | |
| return [ | |
| self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1) | |
| for i in range(0, min(2, len(hidden_state)), 1) | |
| ] | |
| elif role == MusicRole.PAD: | |
| return [ | |
| self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1) | |
| for i in range(0, min(3, len(hidden_state)), 2) | |
| ] | |
| elif role == MusicRole.ACCENT: | |
| note = self.math_utils.norm_to_scale( | |
| hidden_state[0], scale, octave_range=2 | |
| ) + 12 | |
| return [note] | |
| else: # ATMOSPHERE | |
| return [ | |
| self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1) | |
| for i in range(0, min(2, len(hidden_state)), 3) | |
| ] | |
| def calculate_velocity( | |
| self, | |
| role: MusicRole, | |
| attention_strength: float | |
| ) -> int: | |
| """Calculate note velocity based on role and attention.""" | |
| base_velocity = int( | |
| attention_strength * (self.config.velocity_range[1] - self.config.velocity_range[0]) | |
| + self.config.velocity_range[0] | |
| ) | |
| # Role-specific adjustments | |
| if role == MusicRole.MELODY: | |
| velocity = min(base_velocity + 10, 127) | |
| elif role == MusicRole.ACCENT: | |
| velocity = min(base_velocity + 20, 127) | |
| elif role in [MusicRole.PAD, MusicRole.ATMOSPHERE]: | |
| velocity = max(base_velocity - 10, 20) | |
| else: | |
| velocity = base_velocity | |
| return velocity | |
| def calculate_duration( | |
| self, | |
| role: MusicRole, | |
| attention_matrix: np.ndarray | |
| ) -> int: | |
| """Calculate note duration based on role and attention.""" | |
| if role in [MusicRole.PAD, MusicRole.ATMOSPHERE]: | |
| duration = self.config.base_tempo * 4 | |
| elif role == MusicRole.BASS: | |
| duration = self.config.base_tempo | |
| else: | |
| try: | |
| dur_factor = self.math_utils.entropy(attention_matrix.mean(axis=0)) / ( | |
| np.log2(attention_matrix.shape[-1]) + 1e-9 | |
| ) | |
| except Exception: | |
| dur_factor = 0.5 | |
| duration = self.math_utils.quantize_time( | |
| int(self.config.base_tempo * (0.5 + dur_factor * 1.5)), | |
| self.config.quantization_grid | |
| ) | |
| return duration | |
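A single decision of the generator can be exercised in isolation with synthetic inputs; a sketch that reuses the `config` object from the configuration example earlier:

```python
rng = np.random.default_rng(0)
generator = NoteGenerator(config)
scale = np.array(config.scale.notes, dtype=int)
hidden = rng.standard_normal(128)        # stand-in for one token's hidden state

prob = generator.create_note_probability(
    layer_idx=0, token_idx=3, attention_val=0.6,
    hidden_state=hidden, num_tokens=32, role=MusicRole.MELODY,
)
notes = generator.generate_notes_for_role(MusicRole.MELODY, hidden, scale)
velocity = generator.calculate_velocity(MusicRole.MELODY, attention_strength=0.6)
print(round(prob, 3), notes, velocity)
```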
| # Model Interaction | |
| class LatentExtractor(ABC): | |
| """Abstract base class for latent extraction strategies.""" | |
| @abstractmethod | |
| def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents: | |
| """Extract latents from text.""" | |
| pass | |
| class MockLatentExtractor(LatentExtractor): | |
| """Generate mock latents for testing without loading models.""" | |
| def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents: | |
| """Generate synthetic latents based on text.""" | |
| # Simulate token count based on text length | |
| tokens = max(16, min(128, len(text.split()) * 4)) | |
| layers = min(config.num_layers_limit, 6) | |
| # Seed both NumPy and Torch from the text so mock latents are reproducible | |
| # (zlib.crc32 is stable across runs, unlike the salted built-in hash()) | |
| seed_value = zlib.crc32(text.encode("utf-8")) | |
| np.random.seed(seed_value) | |
| torch.manual_seed(seed_value) | |
| hidden_states = [ | |
| torch.randn(1, tokens, 128) for _ in range(layers) | |
| ] | |
| attentions = [ | |
| torch.rand(1, 8, tokens, tokens) for _ in range(layers) | |
| ] | |
| metadata = { | |
| "mode": "mock", | |
| "text_length": len(text), | |
| "generated_tokens": tokens, | |
| "generated_layers": layers | |
| } | |
| return Latents( | |
| hidden_states=hidden_states, | |
| attentions=attentions, | |
| num_layers=layers, | |
| num_tokens=tokens, | |
| metadata=metadata | |
| ) | |
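Mock extraction is handy for checking shapes without downloading a model; a sketch reusing the earlier `config`:

```python
extractor = MockLatentExtractor()
latents = extractor.extract("a quiet forest hums with hidden signals", config)
print(latents.num_layers, latents.num_tokens)
print(latents.hidden_states[0].shape)    # torch.Size([1, <tokens>, 128])
print(latents.metadata["mode"])          # "mock"
```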
| class ModelLatentExtractor(LatentExtractor): | |
| """Extract real latents from transformer models.""" | |
| @spaces.GPU  # ZeroGPU: request a GPU for the model forward pass | |
| def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents: | |
| """Extract latents from a real transformer model.""" | |
| model_name = config.model_name | |
| # Load tokenizer | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| if tokenizer.pad_token is None and tokenizer.eos_token is not None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| # Configure model loading | |
| load_kwargs = { | |
| "output_hidden_states": True, | |
| "output_attentions": True, | |
| "device_map": "cuda" if torch.cuda.is_available() else "cpu", | |
| } | |
| # Set appropriate dtype | |
| try: | |
| load_kwargs["torch_dtype"] = ( | |
| torch.bfloat16 if torch.cuda.is_available() else torch.float32 | |
| ) | |
| except Exception: | |
| pass | |
| # Load model | |
| model = AutoModel.from_pretrained(model_name, **load_kwargs) | |
| # Tokenize input | |
| inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512) | |
| device = next(model.parameters()).device | |
| inputs = {k: v.to(device) for k, v in inputs.items()} | |
| # Get model outputs | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| hidden_states = list(outputs.hidden_states) | |
| attentions = list(outputs.attentions) | |
| # Move to CPU to free VRAM | |
| hidden_states = [hs.to("cpu") for hs in hidden_states] | |
| attentions = [att.to("cpu") for att in attentions] | |
| # Limit layers | |
| layers = min(config.num_layers_limit, len(hidden_states)) | |
| tokens = hidden_states[0].shape[1] | |
| # Clean up | |
| try: | |
| del model | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| except Exception: | |
| pass | |
| metadata = { | |
| "mode": "full_model", | |
| "model_name": model_name, | |
| "actual_layers": len(hidden_states), | |
| "used_layers": layers, | |
| "tokens": tokens | |
| } | |
| return Latents( | |
| hidden_states=hidden_states[:layers], | |
| attentions=attentions[:layers], | |
| num_layers=layers, | |
| num_tokens=tokens, | |
| metadata=metadata | |
| ) | |
| class LatentExtractorFactory: | |
| """Factory for creating appropriate latent extractors.""" | |
| @staticmethod | |
| def create(compute_mode: ComputeMode) -> LatentExtractor: | |
| """Create a latent extractor based on compute mode.""" | |
| if compute_mode == ComputeMode.MOCK_LATENTS: | |
| return MockLatentExtractor() | |
| else: | |
| return ModelLatentExtractor() | |
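Callers only see the factory, so switching between the CPU-only demo path and the real model path is a one-argument change; a minimal sketch:

```python
extractor = LatentExtractorFactory.create(ComputeMode.MOCK_LATENTS)
assert isinstance(extractor, MockLatentExtractor)

extractor = LatentExtractorFactory.create(ComputeMode.FULL_MODEL)
assert isinstance(extractor, ModelLatentExtractor)
```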
| # MIDI Generation | |
| class MIDIRenderer: | |
| """Renders MIDI files from latents.""" | |
| def __init__(self, config: GenerationConfig, instrument_manager: InstrumentPresetManager): | |
| """Initialize MIDI renderer.""" | |
| self.config = config | |
| self.instrument_manager = instrument_manager | |
| self.note_generator = NoteGenerator(config) | |
| self.math_utils = MusicMathUtils() | |
| def render(self, latents: Latents) -> Tuple[bytes, Dict[str, Any]]: | |
| """Render MIDI from latents.""" | |
| # Set random seeds for reproducibility | |
| np.random.seed(self.config.seed) | |
| random.seed(self.config.seed) | |
| torch.manual_seed(self.config.seed) | |
| # Prepare data | |
| scale = np.array(self.config.scale.notes, dtype=int) | |
| num_layers = latents.num_layers | |
| num_tokens = latents.num_tokens | |
| # Convert tensors to numpy | |
| hidden_states = [ | |
| hs.float().numpy() if isinstance(hs, torch.Tensor) else hs | |
| for hs in latents.hidden_states | |
| ] | |
| attentions = [ | |
| att.float().numpy() if isinstance(att, torch.Tensor) else att | |
| for att in latents.attentions | |
| ] | |
| # Get instrument mappings | |
| instrument_mappings = self.instrument_manager.get_preset(self.config.instrument_preset) | |
| # Create MIDI file and tracks | |
| midi_file = MidiFile() | |
| tracks = self._create_tracks(midi_file, num_layers, instrument_mappings) | |
| # Generate notes | |
| stats = self._generate_notes( | |
| tracks, hidden_states, attentions, | |
| scale, num_tokens, instrument_mappings | |
| ) | |
| # Convert to bytes | |
| bio = io.BytesIO() | |
| midi_file.save(file=bio) | |
| bio.seek(0) | |
| # Prepare metadata | |
| metadata = { | |
| "config": self.config.to_dict(), | |
| "latents_info": latents.metadata, | |
| "stats": stats, | |
| "timestamp": time.time() | |
| } | |
| return bio.read(), metadata | |
| def _create_tracks( | |
| self, | |
| midi_file: MidiFile, | |
| num_layers: int, | |
| instrument_mappings: List[InstrumentMapping] | |
| ) -> List[MidiTrack]: | |
| """Create MIDI tracks with instrument assignments.""" | |
| tracks = [] | |
| for layer_idx in range(num_layers): | |
| track = MidiTrack() | |
| midi_file.tracks.append(track) | |
| tracks.append(track) | |
| # Get instrument mapping for this layer | |
| if layer_idx < len(instrument_mappings): | |
| mapping = instrument_mappings[layer_idx] | |
| else: | |
| # Default to piano if not enough mappings | |
| mapping = InstrumentMapping(0, MusicRole.MELODY, layer_idx % 16) | |
| # Set instrument | |
| track.append(Message( | |
| "program_change", | |
| program=mapping.program, | |
| time=0, | |
| channel=mapping.channel | |
| )) | |
| # Add track name | |
| if mapping.name: | |
| track.append(mido.MetaMessage( | |
| "track_name", | |
| name=f"{mapping.name} - {mapping.role.value}", | |
| time=0 | |
| )) | |
| return tracks | |
| def _generate_notes( | |
| self, | |
| tracks: List[MidiTrack], | |
| hidden_states: List[np.ndarray], | |
| attentions: List[np.ndarray], | |
| scale: np.ndarray, | |
| num_tokens: int, | |
| instrument_mappings: List[InstrumentMapping] | |
| ) -> Dict[str, Any]: | |
| """Generate notes for all tracks.""" | |
| current_time = [0] * len(tracks) | |
| notes_count = [0] * len(tracks) | |
| for token_idx in range(num_tokens): | |
| # Update time periodically | |
| if token_idx > 0 and token_idx % 4 == 0: | |
| for layer_idx in range(len(tracks)): | |
| current_time[layer_idx] += self.config.base_tempo | |
| # Calculate panning | |
| pan = 64 + int(32 * np.sin(token_idx * math.pi / max(1, num_tokens))) | |
| # Generate notes for each layer | |
| for layer_idx in range(len(tracks)): | |
| if layer_idx >= len(instrument_mappings): | |
| continue | |
| mapping = instrument_mappings[layer_idx] | |
| # Get attention and hidden state | |
| attn_matrix = attentions[min(layer_idx, len(attentions) - 1)][0, :, token_idx, :] | |
| attention_strength = float(np.mean(attn_matrix)) | |
| layer_vec = hidden_states[layer_idx][0, token_idx] | |
| # Check if note should be played | |
| if not self.note_generator.should_play_note( | |
| layer_idx, token_idx, attention_strength, | |
| layer_vec, num_tokens, mapping.role | |
| ): | |
| continue | |
| # Generate notes | |
| notes_to_play = self.note_generator.generate_notes_for_role( | |
| mapping.role, layer_vec, scale | |
| ) | |
| # Calculate velocity and duration | |
| velocity = self.note_generator.calculate_velocity( | |
| mapping.role, attention_strength | |
| ) | |
| duration = self.note_generator.calculate_duration( | |
| mapping.role, attn_matrix | |
| ) | |
| # Add notes to track | |
| for note in notes_to_play: | |
| note = max(21, min(108, int(note))) # Clamp to piano range | |
| tracks[layer_idx].append(Message( | |
| "note_on", | |
| note=note, | |
| velocity=velocity, | |
| time=current_time[layer_idx], | |
| channel=mapping.channel | |
| )) | |
| tracks[layer_idx].append(Message( | |
| "note_off", | |
| note=note, | |
| velocity=0, | |
| time=duration, | |
| channel=mapping.channel | |
| )) | |
| current_time[layer_idx] = 0 | |
| notes_count[layer_idx] += 1 | |
| # Set panning on first token | |
| if token_idx == 0: | |
| tracks[layer_idx].append(Message( | |
| "control_change", | |
| control=10, | |
| value=pan, | |
| time=0, | |
| channel=mapping.channel | |
| )) | |
| return { | |
| "num_layers": len(tracks), | |
| "num_tokens": num_tokens, | |
| "notes_per_layer": notes_count, | |
| "total_notes": int(sum(notes_count)), | |
| "tempo_ticks_per_beat": int(self.config.base_tempo), | |
| "scale": list(map(int, scale.tolist())), | |
| } | |
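End to end, the renderer turns latents into a standard MIDI byte stream that mido can read back; a sketch using mock latents and the earlier `config`:

```python
renderer = MIDIRenderer(config, InstrumentPresetManager())
latents = MockLatentExtractor().extract("moss, roots, rain", config)
midi_bytes, metadata = renderer.render(latents)

stats = metadata["stats"]
print(stats["total_notes"], "notes across", stats["num_layers"], "tracks")

# Round-trip through mido to confirm the bytes form a valid MIDI file.
midi = MidiFile(file=io.BytesIO(midi_bytes))
print(len(midi.tracks), "tracks,", sum(len(t) for t in midi.tracks), "messages")
```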
| # Main Orchestrator | |
| class LLMForestOrchestra: | |
| """Main orchestrator class that coordinates the entire pipeline.""" | |
| DEFAULT_MODEL = "unsloth/Qwen3-14B-Base" | |
| def __init__(self): | |
| """Initialize the orchestra.""" | |
| self.scale_manager = ScaleManager() | |
| self.instrument_manager = InstrumentPresetManager() | |
| self.saved_configs: Dict[str, GenerationConfig] = {} | |
| def generate( | |
| self, | |
| text: str, | |
| model_name: str, | |
| compute_mode: str, | |
| base_tempo: int, | |
| velocity_range: Tuple[int, int], | |
| scale_name: str, | |
| custom_scale_notes: Optional[List[int]], | |
| num_layers: int, | |
| instrument_preset: str, | |
| seed: int, | |
| quantization_grid: int = 120, | |
| octave_range: int = 2, | |
| dynamics_curve: str = "linear" | |
| ) -> Tuple[str, Dict[str, Any]]: | |
| """Generate MIDI from text input.""" | |
| # Get or create scale | |
| if scale_name == "Custom": | |
| if not custom_scale_notes: | |
| raise ValueError("Custom scale requires note list") | |
| scale = ScaleDefinition("Custom", custom_scale_notes) | |
| else: | |
| scale = self.scale_manager.get_scale(scale_name) | |
| if scale is None: | |
| raise ValueError(f"Unknown scale: {scale_name}") | |
| # Create configuration | |
| config = GenerationConfig( | |
| model_name=model_name or self.DEFAULT_MODEL, | |
| compute_mode=ComputeMode(compute_mode), | |
| base_tempo=base_tempo, | |
| velocity_range=velocity_range, | |
| scale=scale, | |
| num_layers_limit=num_layers, | |
| seed=seed, | |
| instrument_preset=instrument_preset, | |
| quantization_grid=quantization_grid, | |
| octave_range=octave_range, | |
| dynamics_curve=dynamics_curve | |
| ) | |
| # Validate configuration | |
| config.validate() | |
| # Extract latents | |
| extractor = LatentExtractorFactory.create(config.compute_mode) | |
| latents = extractor.extract(text, config) | |
| # Render MIDI | |
| renderer = MIDIRenderer(config, self.instrument_manager) | |
| midi_bytes, metadata = renderer.render(latents) | |
| # Save MIDI file | |
| filename = f"llm_forest_orchestra_{uuid.uuid4().hex[:8]}.mid" | |
| with open(filename, "wb") as f: | |
| f.write(midi_bytes) | |
| return filename, metadata | |
| def save_config(self, name: str, config: GenerationConfig): | |
| """Save a configuration for later use.""" | |
| self.saved_configs[name] = config | |
| def load_config(self, name: str) -> Optional[GenerationConfig]: | |
| """Load a saved configuration.""" | |
| return self.saved_configs.get(name) | |
| def export_config(self, config: GenerationConfig, filepath: str): | |
| """Export configuration to JSON file.""" | |
| with open(filepath, "w") as f: | |
| json.dump(config.to_dict(), f, indent=2) | |
| def import_config(self, filepath: str) -> GenerationConfig: | |
| """Import configuration from JSON file.""" | |
| with open(filepath, "r") as f: | |
| data = json.load(f) | |
| return GenerationConfig.from_dict(data, self.scale_manager) | |
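The orchestrator can also be driven programmatically, without the Gradio UI; a minimal sketch (it writes a .mid file into the working directory):

```python
orchestra = LLMForestOrchestra()
midi_path, metadata = orchestra.generate(
    text="Joy cascades in golden waterfalls, crashing into pools of melancholy blue.",
    model_name="",                    # empty string falls back to DEFAULT_MODEL
    compute_mode="Mock latents",      # CPU-friendly; "Full model" needs a GPU
    base_tempo=480,
    velocity_range=(40, 90),
    scale_name="C pentatonic",
    custom_scale_notes=None,
    num_layers=6,
    instrument_preset="Ensemble (melody+bass+pad etc.)",
    seed=42,
)
print(midi_path, metadata["stats"]["total_notes"])
```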
| # Gradio UI | |
| class GradioInterface: | |
| """Manages the Gradio user interface.""" | |
| DESCRIPTION = """ | |
| # 🌲 LLM Forest Orchestra – Sonify Transformer Internals | |
| Transform the hidden states and attention patterns of language models into multi-layered musical compositions. | |
| ## 🍄 Inspiration | |
| This project is inspired by the way **mushrooms and mycelial networks in forests** | |
| connect plants and trees, forming a living web of communication and resource sharing. | |
| These connections can be turned into ethereal music. | |
| Just as signals move through these hidden connections, transformer models also | |
| pass hidden states and attentions across their layers. Here, those hidden | |
| connections are translated into **music**, analogous to the forest's secret orchestra. | |
| ## Features | |
| - **Two compute modes**: Full model (GPU) or Mock latents (CPU-friendly) | |
| - **Multiple musical scales**: From pentatonic to chromatic | |
| - **Instrument presets**: Orchestral, electronic, ensemble, and more | |
| - **Advanced controls**: Dynamics curves, quantization, velocity ranges | |
| - **Export**: Standard MIDI files for further editing in your DAW | |
| """ | |
| EXAMPLE_TEXT = """Joy cascades in golden waterfalls, crashing into pools of melancholy blue. | |
| Anger burns red through veins of marble, while serenity floats on clouds of softest grey. | |
| Love pulses in waves of crimson and rose, intertwining with longing's purple haze. | |
| Each feeling resonates at its own frequency, painting music across the soul's canvas.""" | |
| def __init__(self, orchestra: LLMForestOrchestra): | |
| """Initialize the interface.""" | |
| self.orchestra = orchestra | |
| def create_interface(self) -> gr.Blocks: | |
| """Create the Gradio interface.""" | |
| with gr.Blocks(title="LLM Forest Orchestra", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown(self.DESCRIPTION) | |
| with gr.Tabs(): | |
| with gr.TabItem("🎵 Generate Music"): | |
| self._create_generation_tab() | |
| return demo | |
| def _create_generation_tab(self): | |
| """Create the main generation tab.""" | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| text_input = gr.Textbox( | |
| value=self.EXAMPLE_TEXT, | |
| label="Input Text", | |
| lines=8, | |
| placeholder="Enter text to sonify..." | |
| ) | |
| model_name = gr.Textbox( | |
| value=self.orchestra.DEFAULT_MODEL, | |
| label="Hugging Face Model", | |
| info="Model must support output_hidden_states and output_attentions" | |
| ) | |
| compute_mode = gr.Radio( | |
| choices=["Full model", "Mock latents"], | |
| value="Mock latents", | |
| label="Compute Mode", | |
| info="Mock latents for quick CPU-only demo" | |
| ) | |
| with gr.Row(): | |
| instrument_preset = gr.Dropdown( | |
| choices=self.orchestra.instrument_manager.list_presets(), | |
| value="Ensemble (melody+bass+pad etc.)", | |
| label="Instrument Preset" | |
| ) | |
| scale_choice = gr.Dropdown( | |
| choices=self.orchestra.scale_manager.list_scales() + ["Custom"], | |
| value="C pentatonic", | |
| label="Musical Scale" | |
| ) | |
| custom_scale = gr.Textbox( | |
| value="", | |
| label="Custom Scale Notes", | |
| placeholder="60,62,65,67,70", | |
| visible=False | |
| ) | |
| with gr.Row(): | |
| base_tempo = gr.Slider( | |
| 120, 960, | |
| value=480, | |
| step=1, | |
| label="Tempo (ticks per beat)" | |
| ) | |
| num_layers = gr.Slider( | |
| 1, 6, | |
| value=6, | |
| step=1, | |
| label="Max Layers" | |
| ) | |
| with gr.Row(): | |
| velocity_low = gr.Slider( | |
| 1, 126, | |
| value=40, | |
| step=1, | |
| label="Min Velocity" | |
| ) | |
| velocity_high = gr.Slider( | |
| 2, 127, | |
| value=90, | |
| step=1, | |
| label="Max Velocity" | |
| ) | |
| seed = gr.Number( | |
| value=42, | |
| precision=0, | |
| label="Random Seed" | |
| ) | |
| generate_btn = gr.Button( | |
| "πΌ Generate MIDI", | |
| variant="primary", | |
| size="lg" | |
| ) | |
| with gr.Column(scale=1): | |
| midi_output = gr.File( | |
| label="Generated MIDI File", | |
| file_types=[".mid", ".midi"] | |
| ) | |
| stats_display = gr.Markdown(label="Quick Stats") | |
| metadata_json = gr.Code( | |
| label="Metadata (JSON)", | |
| language="json" | |
| ) | |
| with gr.Row(): | |
| play_instructions = gr.Markdown( | |
| """ | |
| ### 🎧 How to Play | |
| 1. Download the MIDI file | |
| 2. Open in any DAW or MIDI player | |
| 3. Adjust instruments and effects as desired | |
| 4. Export to audio format | |
| """ | |
| ) | |
| # Set up interactions | |
| def update_custom_scale_visibility(choice): | |
| return gr.update(visible=(choice == "Custom")) | |
| scale_choice.change( | |
| update_custom_scale_visibility, | |
| inputs=[scale_choice], | |
| outputs=[custom_scale] | |
| ) | |
| def generate_wrapper( | |
| text, model_name, compute_mode, base_tempo, | |
| velocity_low, velocity_high, scale_choice, | |
| custom_scale, num_layers, instrument_preset, seed | |
| ): | |
| """Wrapper for generation with error handling.""" | |
| try: | |
| # Parse custom scale if needed | |
| custom_notes = None | |
| if scale_choice == "Custom" and custom_scale: | |
| custom_notes = [int(x.strip()) for x in custom_scale.split(",")] | |
| # Generate | |
| filename, metadata = self.orchestra.generate( | |
| text=text, | |
| model_name=model_name, | |
| compute_mode=compute_mode, | |
| base_tempo=int(base_tempo), | |
| velocity_range=(int(velocity_low), int(velocity_high)), | |
| scale_name=scale_choice, | |
| custom_scale_notes=custom_notes, | |
| num_layers=int(num_layers), | |
| instrument_preset=instrument_preset, | |
| seed=int(seed) | |
| ) | |
| # Format stats | |
| stats = metadata.get("stats", {}) | |
| stats_text = f""" | |
| ### Generation Statistics | |
| - **Layers Used**: {stats.get('num_layers', 'N/A')} | |
| - **Tokens Processed**: {stats.get('num_tokens', 'N/A')} | |
| - **Total Notes**: {stats.get('total_notes', 'N/A')} | |
| - **Notes per Layer**: {stats.get('notes_per_layer', [])} | |
| - **Scale**: {stats.get('scale', [])} | |
| - **Tempo**: {stats.get('tempo_ticks_per_beat', 'N/A')} ticks/beat | |
| """ | |
| return filename, stats_text, json.dumps(metadata, indent=2) | |
| except Exception as e: | |
| error_msg = f"### ❌ Error\n{str(e)}" | |
| return None, error_msg, json.dumps({"error": str(e)}, indent=2) | |
| generate_btn.click( | |
| fn=generate_wrapper, | |
| inputs=[ | |
| text_input, model_name, compute_mode, base_tempo, | |
| velocity_low, velocity_high, scale_choice, | |
| custom_scale, num_layers, instrument_preset, seed | |
| ], | |
| outputs=[midi_output, stats_display, metadata_json] | |
| ) | |
| # Main Entry Point | |
| def main(): | |
| """Main entry point for the application.""" | |
| # Initialize orchestra | |
| orchestra = LLMForestOrchestra() | |
| # Create interface | |
| interface = GradioInterface(orchestra) | |
| demo = interface.create_interface() | |
| # Launch | |
| demo.launch() | |
| if __name__ == "__main__": | |
| main() |