File size: 3,978 Bytes
10d7861
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import warnings
from pathlib import Path
from typing import Literal

import numpy as np

try:  # absolute imports when installed
    from trackio.media.utils import check_ffmpeg_installed, check_path
except ImportError:  # relative imports for local execution on Spaces
    from media.utils import check_ffmpeg_installed, check_path

# Try to import pydub, but make it optional
try:
    from pydub import AudioSegment

    PYDUB_AVAILABLE = True
except ImportError:
    PYDUB_AVAILABLE = False
    AudioSegment = None

# Container formats that write_audio accepts for its `format` argument.
SUPPORTED_FORMATS = ["wav", "mp3"]
# Static-typing counterpart of SUPPORTED_FORMATS; keep the two in sync.
AudioFormatType = Literal["wav", "mp3"]


def ensure_int16_pcm(data: np.ndarray) -> np.ndarray:
    """
    Convert input audio array to contiguous int16 PCM.

    Floating inputs are sanitized with ``np.nan_to_num`` (NaN becomes 0,
    infinities become the largest finite values of the dtype), peak-normalized
    so the largest absolute sample is 1.0, and then scaled to the int16 range.
    An all-zero or empty floating input skips the normalization step.

    Integer inputs are rescaled to 16 bits: int32 is divided by 65536,
    uint16 is re-centred around zero, uint8 is stretched to the full int16
    range, and int8 is multiplied by 256. int16 input is passed through.

    The input array is never modified. For int16 input that is already
    C-contiguous the returned array may be the input object itself.

    Args:
        data: Array-like of audio samples, 1D (mono) or 2D
            ([samples, channels]).

    Returns:
        A C-contiguous ``np.int16`` array with the same shape as the input.

    Raises:
        ValueError: If the input is not 1D or 2D.
        TypeError: If the dtype is neither floating nor one of int16, int32,
            uint16, uint8, int8.

    Warns:
        UserWarning: Whenever the input dtype is not int16.
    """
    arr = np.asarray(data)
    if arr.ndim not in (1, 2):
        raise ValueError("Audio data must be 1D (mono) or 2D ([samples, channels])")

    if arr.dtype != np.int16:
        warnings.warn(
            f"Converting {arr.dtype} audio to int16 PCM; pass int16 to avoid conversion.",
            stacklevel=2,
        )

    # Floating types: normalize to peak 1.0, then scale to int16
    if np.issubdtype(arr.dtype, np.floating):
        # np.asarray does not copy an ndarray, so `arr` can alias the caller's
        # buffer. Sanitize a copy (nan_to_num's default) instead of writing
        # in place: this keeps the caller's data intact and also works for
        # read-only arrays.
        cleaned = np.nan_to_num(arr)
        peak = float(np.max(np.abs(cleaned))) if cleaned.size else 0.0
        if peak > 0.0:
            cleaned = cleaned / peak
        out = (cleaned * 32767.0).clip(-32768, 32767).astype(np.int16, copy=False)
        return np.ascontiguousarray(out)

    # Integer types: fixed rescaling per source dtype. Arithmetic is done in
    # int32 so intermediate values cannot overflow the narrower source type.
    converters = {
        np.dtype(np.int16): lambda a: a,
        np.dtype(np.int32): lambda a: (
            (a.astype(np.int32) // 65536).astype(np.int16, copy=False)
        ),
        np.dtype(np.uint16): lambda a: (
            (a.astype(np.int32) - 32768).astype(np.int16, copy=False)
        ),
        np.dtype(np.uint8): lambda a: (
            (a.astype(np.int32) * 257 - 32768).astype(np.int16, copy=False)
        ),
        np.dtype(np.int8): lambda a: (
            (a.astype(np.int32) * 256).astype(np.int16, copy=False)
        ),
    }

    conv = converters.get(arr.dtype)
    if conv is None:
        raise TypeError(f"Unsupported audio dtype: {arr.dtype}")
    return np.ascontiguousarray(conv(arr))


def write_audio(
    data: np.ndarray,
    sample_rate: int,
    filename: str | Path,
    format: AudioFormatType = "wav",
) -> None:
    """
    Encode audio samples to a file in the requested format.

    The samples are converted to int16 PCM with ensure_int16_pcm. When pydub
    is importable the file is produced through ``AudioSegment.export``;
    otherwise only WAV output is possible, via write_wav_simple.

    Args:
        data: Audio samples, 1D (mono) or 2D ([samples, channels]).
        sample_rate: Sampling rate; must be a positive ``int``.
        filename: Destination path. It is passed to check_path first
            (defined in media.utils; presumably validates or prepares the
            location -- confirm there).
        format: One of SUPPORTED_FORMATS.

    Raises:
        ValueError: If sample_rate is not a positive int or format is not
            in SUPPORTED_FORMATS.
        ImportError: If pydub is unavailable and format is not "wav".
    """
    rate_ok = isinstance(sample_rate, int) and sample_rate > 0
    if not rate_ok:
        raise ValueError(f"Invalid sample_rate: {sample_rate}")
    if format not in SUPPORTED_FORMATS:
        raise ValueError(
            f"Unsupported format: {format}. Supported: {SUPPORTED_FORMATS}"
        )

    check_path(filename)

    samples = ensure_int16_pcm(data)
    wants_wav = format == "wav"

    if PYDUB_AVAILABLE:
        # check_ffmpeg_installed is only consulted for non-WAV output.
        if not wants_wav:
            check_ffmpeg_installed()

        if samples.ndim == 1:
            n_channels = 1
        else:
            n_channels = samples.shape[1]

        segment = AudioSegment(
            samples.tobytes(),
            frame_rate=sample_rate,
            sample_width=2,  # bytes per int16 sample
            channels=n_channels,
        )

        # export() hands back the file object it wrote to; close it here.
        handle = segment.export(str(filename), format=format)
        handle.close()
        return

    # No pydub: the stdlib-based writer can only produce WAV.
    if not wants_wav:
        raise ImportError(
            "pydub is required for non-WAV formats. Install with: pip install pydub"
        )
    write_wav_simple(filename, samples, sample_rate)


def write_wav_simple(
    file_path: str | Path, data: np.ndarray, sample_rate: int = 44100
) -> None:
    """
    Write 16-bit PCM WAV using only the standard library ``wave`` module.

    Used as the fallback when pydub is not available. The samples are run
    through ensure_int16_pcm, so any dtype it supports is accepted.

    Args:
        file_path: Destination file path.
        data: Audio samples, 1D (mono) or 2D ([samples, channels]).
        sample_rate: Frame rate written to the WAV header.
    """
    import wave

    samples = ensure_int16_pcm(data)
    if samples.ndim > 2:
        raise ValueError("Audio data must be 1D (mono) or 2D (stereo)")

    if samples.ndim == 1:
        n_channels = 1
    else:
        n_channels = samples.shape[1]
    raw_frames = samples.tobytes()

    with wave.open(str(file_path), "wb") as writer:
        writer.setnchannels(n_channels)
        writer.setsampwidth(2)  # int16 -> 2 bytes per sample
        writer.setframerate(sample_rate)
        writer.writeframes(raw_frames)