import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
import matplotlib.pyplot as plt
import os
from typing import List, Tuple

from config import LOGS_DIR


# Some utils:
def load_audio_files(file_paths: List[str]) -> List[Tuple[torch.Tensor, int]]:
    """
    Load multiple audio files and ensure they have the same length.

    Args:
        file_paths: List of paths to audio files

    Returns:
        List of tuples containing audio data and sample rate
    """
    audio_data = []
    for path in file_paths:
        # Load audio file
        waveform, sample_rate = torchaudio.load(path)
        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        audio_data.append((waveform.squeeze(), sample_rate))

    # Verify all audio files have the same length and sample rate
    lengths = [len(audio) for audio, _ in audio_data]
    sample_rates = [sr for _, sr in audio_data]
    if len(set(lengths)) > 1:
        raise ValueError(f"Audio files have different lengths: {lengths}")
    if len(set(sample_rates)) > 1:
        raise ValueError(f"Audio files have different sample rates: {sample_rates}")

    return audio_data
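
# Usage sketch (hypothetical paths; any set of same-length, same-rate files works):
#   tracks = load_audio_files(["stems/drums.wav", "stems/bass.wav"])
#   waveform, sr = tracks[0]  # mono 1-D tensor plus its sample rate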

def normalize_audio_volumes(audio_data: List[Tuple[torch.Tensor, int]]) -> List[Tuple[torch.Tensor, int]]:
    """
    Normalize the volume of each audio file to the same energy level.

    Args:
        audio_data: List of tuples containing audio data and sample rate

    Returns:
        List of tuples containing normalized audio data and sample rate
    """
    normalized_data = []

    # Calculate RMS (root mean square) for each audio
    rms_values = []
    for audio, sr in audio_data:
        # Energy is the mean squared amplitude; RMS is its square root
        energy = torch.mean(audio ** 2)
        rms = torch.sqrt(energy)
        rms_values.append(rms)

    # Find the target RMS (the median, to avoid outliers); stack the
    # 0-dim RMS tensors rather than rebuilding a tensor from them
    target_rms = torch.median(torch.stack(rms_values))

    # Normalize each audio to the target RMS
    for (audio, sr), rms in zip(audio_data, rms_values):
        if rms > 0:  # Avoid division by zero
            # Calculate scaling factor and apply it
            scaling_factor = target_rms / rms
            normalized_audio = audio * scaling_factor
        else:
            normalized_audio = audio
        normalized_data.append((normalized_audio, sr))

    return normalized_data
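
# Invariant sketch with assumed synthetic signals: after normalization both
# clips share a common target RMS (the median of the original RMS values).
#   loud, quiet = torch.randn(16000), 0.05 * torch.randn(16000)
#   out = normalize_audio_volumes([(loud, 16000), (quiet, 16000)])
#   rms = [torch.sqrt(torch.mean(a ** 2)).item() for a, _ in out]  # rms[0] ≈ rms[1]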

def plot_energy_comparison(
    original_metrics: List[dict],
    normalized_metrics: List[dict],
    file_names: List[str],
    output_path: str = "./logs/energy_comparison.png"
) -> None:
    """
    Plot a comparison of energy metrics before and after normalization.

    Args:
        original_metrics: List of dictionaries containing metrics for original audio
        normalized_metrics: List of dictionaries containing metrics for normalized audio
        file_names: List of audio file names
        output_path: Path to save the plot
    """
    fig, axs = plt.subplots(2, 2, figsize=(14, 10))

    # Extract metrics
    orig_rms = [m['rms'] for m in original_metrics]
    norm_rms = [m['rms'] for m in normalized_metrics]
    orig_peak = [m['peak'] for m in original_metrics]
    norm_peak = [m['peak'] for m in normalized_metrics]
    orig_dr = [m['dynamic_range_db'] for m in original_metrics]
    norm_dr = [m['dynamic_range_db'] for m in normalized_metrics]
    orig_cf = [m['crest_factor'] for m in original_metrics]
    norm_cf = [m['crest_factor'] for m in normalized_metrics]

    # Prepare x-axis
    x = np.arange(len(file_names))
    width = 0.35

    # Plot RMS (volume)
    axs[0, 0].bar(x - width/2, orig_rms, width, label='Original')
    axs[0, 0].bar(x + width/2, norm_rms, width, label='Normalized')
    axs[0, 0].set_title('RMS Energy (Volume)')
    axs[0, 0].set_xticks(x)
    axs[0, 0].set_xticklabels(file_names, rotation=45, ha='right')
    axs[0, 0].set_ylabel('RMS Value')
    axs[0, 0].legend()

    # Plot peak amplitude
    axs[0, 1].bar(x - width/2, orig_peak, width, label='Original')
    axs[0, 1].bar(x + width/2, norm_peak, width, label='Normalized')
    axs[0, 1].set_title('Peak Amplitude')
    axs[0, 1].set_xticks(x)
    axs[0, 1].set_xticklabels(file_names, rotation=45, ha='right')
    axs[0, 1].set_ylabel('Peak Value')
    axs[0, 1].legend()

    # Plot dynamic range
    axs[1, 0].bar(x - width/2, orig_dr, width, label='Original')
    axs[1, 0].bar(x + width/2, norm_dr, width, label='Normalized')
    axs[1, 0].set_title('Dynamic Range (dB)')
    axs[1, 0].set_xticks(x)
    axs[1, 0].set_xticklabels(file_names, rotation=45, ha='right')
    axs[1, 0].set_ylabel('dB')
    axs[1, 0].legend()

    # Plot crest factor
    axs[1, 1].bar(x - width/2, orig_cf, width, label='Original')
    axs[1, 1].bar(x + width/2, norm_cf, width, label='Normalized')
    axs[1, 1].set_title('Crest Factor (Peak-to-RMS Ratio)')
    axs[1, 1].set_xticks(x)
    axs[1, 1].set_xticklabels(file_names, rotation=45, ha='right')
    axs[1, 1].set_ylabel('Ratio')
    axs[1, 1].legend()

    plt.tight_layout()

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    # Save the plot
    plt.savefig(output_path)
    plt.close()
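
# Input sketch: each metrics dict is expected to carry the keys produced by
# calculate_audio_metrics below ('rms', 'peak', 'dynamic_range_db', 'crest_factor').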

def calculate_audio_metrics(audio_data: List[Tuple[torch.Tensor, int]]) -> List[dict]:
    """
    Calculate various audio metrics for each audio file.

    Args:
        audio_data: List of tuples containing audio data and sample rate

    Returns:
        List of dictionaries containing metrics
    """
    metrics = []
    for audio, sr in audio_data:
        # Calculate RMS (root mean square)
        energy = torch.mean(audio ** 2)
        rms = torch.sqrt(energy)

        # Calculate peak amplitude
        peak = torch.max(torch.abs(audio))

        # Calculate dynamic range: peak over the smallest non-zero sample, in dB.
        # Guard with numel() so an all-zero signal does not trigger a
        # min-of-empty-tensor error; treat its range as infinite instead.
        non_zero = torch.abs(audio[audio != 0])
        if non_zero.numel() > 0:
            min_non_zero = torch.min(non_zero)
            dynamic_range_db = 20 * torch.log10(peak / min_non_zero)
        else:
            dynamic_range_db = torch.tensor(float('inf'))

        # Calculate crest factor (peak-to-RMS ratio)
        crest_factor = peak / rms if rms > 0 else torch.tensor(float('inf'))

        metrics.append({
            'rms': rms.item(),
            'peak': peak.item(),
            'dynamic_range_db': dynamic_range_db.item() if not torch.isinf(dynamic_range_db) else float('inf'),
            'crest_factor': crest_factor.item() if not torch.isinf(crest_factor) else float('inf')
        })

    return metrics
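
# Rough sanity check on an assumed synthetic tone (not part of the pipeline):
#   t = torch.arange(16000) / 16000.0
#   m = calculate_audio_metrics([(torch.sin(2 * torch.pi * 440 * t), 16000)])[0]
#   # A full-scale sine: rms ≈ 0.707, peak ≈ 1.0, crest_factor ≈ sqrt(2)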

def create_weighted_composite(
    audio_data: List[Tuple[torch.Tensor, int]],
    weights: List[float]
) -> torch.Tensor:
    """
    Create a weighted composite of multiple audio files.

    Args:
        audio_data: List of tuples containing audio data and sample rate
        weights: List of weights for each audio file

    Returns:
        Weighted composite audio data
    """
    if len(audio_data) != len(weights):
        raise ValueError("Number of audio files and weights must match")

    # Normalize weights to sum to 1
    weights = torch.tensor(weights) / sum(weights)

    # Initialize composite audio with zeros
    composite = torch.zeros_like(audio_data[0][0])

    # Add weighted audio data
    for (audio, _), weight in zip(audio_data, weights):
        composite += audio * weight

    # Normalize to prevent clipping
    max_val = torch.max(torch.abs(composite))
    if max_val > 1.0:
        composite = composite / max_val

    return composite
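
# Mixing sketch (hypothetical tensors): a 70/30 blend of two equal-length clips.
#   a, b = torch.randn(16000), torch.randn(16000)
#   mix = create_weighted_composite([(a, 16000), (b, 16000)], [0.7, 0.3])
#   # mix == 0.7 * a + 0.3 * b, rescaled only if its peak would exceed 1.0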

def create_melspectrograms(
    audio_data: List[Tuple[torch.Tensor, int]],
    composite: torch.Tensor,
    sr: int
) -> List[torch.Tensor]:
    """
    Create melspectrograms for individual audio files and the composite.

    Args:
        audio_data: List of tuples containing audio data and sample rate
        composite: Composite audio data
        sr: Sample rate

    Returns:
        List of melspectrogram data
    """
    specs = []

    # Create mel spectrogram transform
    mel_transform = T.MelSpectrogram(
        sample_rate=sr,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        f_max=8000
    )

    # Generate spectrograms for individual audio files
    for audio, _ in audio_data:
        melspec = mel_transform(audio)
        specs.append(melspec)

    # Generate spectrogram for composite audio
    composite_melspec = mel_transform(composite)
    specs.append(composite_melspec)

    return specs
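
# Shape note for the settings above: a 1-D input of n samples yields a power
# spectrogram of shape (n_mels, 1 + n // hop_length), i.e. (128, 1 + n // 512),
# since torchaudio's MelSpectrogram centers frames by default.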

def plot_melspectrograms(
    specs: List[torch.Tensor],
    sr: int,
    file_names: List[str],
    weights: List[float],
    output_path: str = "melspectrograms.png"
) -> None:
    """
    Plot melspectrograms for individual audio files and the composite.

    Args:
        specs: List of melspectrogram data
        sr: Sample rate
        file_names: List of audio file names
        weights: List of weights for each audio file
        output_path: Path to save the plot
    """
    fig, axs = plt.subplots(len(specs), 1, figsize=(12, 4 * len(specs)))

    # Create labels for the plots
    labels = [f"{name} (weight: {weight:.2f})" for name, weight in zip(file_names, weights)]
    labels.append("Composite.wav")

    # Convert to dB scale (similar to librosa's power_to_db)
    def power_to_db(spec):
        return 10 * torch.log10(spec + 1e-10)

    # Plot each melspectrogram
    for i, (spec, label) in enumerate(zip(specs, labels)):
        spec_db = power_to_db(spec).numpy().squeeze()

        # With a single subplot, plt.subplots returns a bare Axes, not an array
        ax = axs if len(specs) == 1 else axs[i]

        # Note: the extent maps mel bins linearly onto 0..sr/2 Hz, which is only
        # an approximation, since the mel scale is nonlinear and capped at f_max
        ax.imshow(
            spec_db,
            aspect='auto',
            origin='lower',
            interpolation='none',
            extent=[0, spec_db.shape[1], 0, sr/2]
        )
        ax.set_title(label)
        ax.set_ylabel('Frequency (Hz)')
        ax.set_xlabel('Time Frames')
        # No colorbar by design

    plt.tight_layout()

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    # Save the plot
    plt.savefig(output_path, dpi=300)
    plt.close()

def compose_audio(
    file_paths: List[str],
    weights: List[float],
    output_audio_path: str = os.path.join(LOGS_DIR, "composite.wav"),
    output_plot_path: str = os.path.join(LOGS_DIR, "plot/melspectrograms.png"),
    energy_plot_path: str = os.path.join(LOGS_DIR, "plot/energy_comparison.png")
) -> None:
    """
    Main function to process audio files and create visualizations.

    Args:
        file_paths: List of paths to audio files (any number; the commented
            example below uses four)
        weights: List of weights for each audio file
        output_audio_path: Path to save the composite audio
        output_plot_path: Path to save the melspectrogram plot
        energy_plot_path: Path to save the energy comparison plot
    """
    # Load audio files
    audio_data = load_audio_files(file_paths)

    # Calculate metrics for original audio
    print("Calculating metrics for original audio...")
    original_metrics = calculate_audio_metrics(audio_data)

    # Normalize audio volumes to the same energy level
    print("Normalizing audio volumes...")
    normalized_audio_data = normalize_audio_volumes(audio_data)

    # Calculate metrics for normalized audio
    print("Calculating metrics for normalized audio...")
    normalized_metrics = calculate_audio_metrics(normalized_audio_data)

    # Print energy comparison
    print("\nAudio Energy Comparison (RMS values):")
    print("-" * 50)
    print(f"{'File':<20} {'Original':<15} {'Normalized':<15} {'Scaling Factor':<15}")
    print("-" * 50)
    for i, path in enumerate(file_paths):
        file_name = os.path.basename(path)
        orig_rms = original_metrics[i]['rms']
        norm_rms = normalized_metrics[i]['rms']
        scaling = norm_rms / orig_rms if orig_rms > 0 else float('inf')
        print(f"{file_name[:20]:<20} {orig_rms:<15.6f} {norm_rms:<15.6f} {scaling:<15.6f}")

    # Create energy comparison plot
    print("\nCreating energy comparison plot...")
    file_names = [os.path.basename(path) for path in file_paths]
    plot_energy_comparison(original_metrics, normalized_metrics, file_names, energy_plot_path)

    # Get sample rate (all files are verified to share one)
    sr = normalized_audio_data[0][1]

    # Create weighted composite
    print("\nCreating weighted composite...")
    composite = create_weighted_composite(normalized_audio_data, weights)

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(output_audio_path) or '.', exist_ok=True)

    # Save composite audio
    print("Saving composite audio...")
    torchaudio.save(output_audio_path, composite.unsqueeze(0), sr)

    # Create melspectrograms for the normalized audio (not the originals)
    print("Creating melspectrograms for normalized audio...")
    specs = create_melspectrograms(normalized_audio_data, composite, sr)

    # Plot melspectrograms
    print("Plotting melspectrograms...")
    plot_melspectrograms(specs, sr, file_names, weights, output_plot_path)

    print(f"\nComposite audio saved to {output_audio_path}")
    print(f"Melspectrograms saved to {output_plot_path}")
    print(f"Energy comparison saved to {energy_plot_path}")
| print(f"Composite audio saved to {output_audio_path}") | |
| print(f"Melspectrograms saved to {output_plot_path}") | |

# if __name__ == "__main__":
#     import argparse
#     parser = argparse.ArgumentParser(description="Mix audio files with weights and create melspectrograms")
#     parser.add_argument("--files", nargs="+", required=True, help="Paths to audio files")
#     parser.add_argument("--weights", nargs="+", type=float, required=True, help="Weights for each audio file")
#     parser.add_argument("--output-audio", default="./logs/composite.wav", help="Path to save the composite audio")
#     parser.add_argument("--output-plot", default="./logs/melspectrograms.png", help="Path to save the melspectrogram plot")
#     args = parser.parse_args()
#     os.makedirs("./logs", exist_ok=True)
#     compose_audio(args.files, args.weights, args.output_audio, args.output_plot)

# Example usage:
# python audio_mixer.py --files audio1.wav audio2.wav audio3.wav audio4.wav --weights 0.4 0.3 0.2 0.1