Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| # Copyright 2025 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import google.generativeai as genai | |
| import os | |
| import struct | |
| import re | |
| import logging | |
| from cache import cache | |
| # Add these imports for MP3 conversion | |
| from pydub import AudioSegment | |
| import io | |
# --- Constants ---
# API key handed to the Gemini client below; None when the variable is unset.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
# Speech generation is opt-in: only the value "true" (any letter case) enables it.
GENERATE_SPEECH = os.getenv("GENERATE_SPEECH", "false").lower() == "true"
# Model requested for text-to-speech.
TTS_MODEL = "gemini-2.5-flash-preview-tts"
# MIME type assumed for audio whose MIME type is missing from the response.
DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"

# --- Configuration ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
genai.configure(api_key=GEMINI_API_KEY)
class TTSGenerationError(Exception):
    """Raised when text-to-speech audio cannot be generated."""
| # --- Helper functions for audio processing --- | |
| def parse_audio_mime_type(mime_type: str) -> dict[str, int | None]: | |
| """ | |
| Parses bits per sample and rate from an audio MIME type string. | |
| e.g., "audio/L16;rate=24000" -> {"bits_per_sample": 16, "rate": 24000} | |
| """ | |
| bits_per_sample = 16 # Default | |
| rate = 24000 # Default | |
| parts = mime_type.split(";") | |
| for param in parts: | |
| param = param.strip().lower() | |
| if param.startswith("rate="): | |
| try: | |
| rate_str = param.split("=", 1)[1] | |
| rate = int(rate_str) | |
| except (ValueError, IndexError): | |
| pass # Keep default if parsing fails | |
| elif re.match(r"audio/l\d+", param): # Matches audio/L<digits> | |
| try: | |
| bits_str = param.split("l",1)[1] | |
| bits_per_sample = int(bits_str) | |
| except (ValueError, IndexError): | |
| pass # Keep default | |
| return {"bits_per_sample": bits_per_sample, "rate": rate} | |
def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
    """
    Prepends a 44-byte RIFF/WAVE header to raw audio samples.

    The bit depth and sample rate are taken from ``mime_type`` via
    ``parse_audio_mime_type``; the audio is always described as mono.
    The sample bytes themselves are copied through unchanged.

    Args:
        audio_data: Raw sample bytes to wrap.
        mime_type: MIME type describing the samples, e.g. "audio/L16;rate=24000".

    Returns:
        The header followed by ``audio_data``.
    """
    audio_format = parse_audio_mime_type(mime_type)
    sample_bits = audio_format["bits_per_sample"]
    samples_per_second = audio_format["rate"]

    channels = 1  # Mono
    frame_size = channels * (sample_bits // 8)  # bytes per sample frame
    payload_size = len(audio_data)

    # "<" selects little-endian packing without padding, so packing the three
    # sections separately yields the same bytes as one combined pack call.
    riff_section = struct.pack("<4sI4s", b"RIFF", 36 + payload_size, b"WAVE")
    fmt_section = struct.pack(
        "<4sIHHIIHH",
        b"fmt ",
        16,  # size of the fmt payload that follows
        1,   # format tag 1 = PCM
        channels,
        samples_per_second,
        samples_per_second * frame_size,  # byte rate
        frame_size,                       # block align
        sample_bits,
    )
    data_section = struct.pack("<4sI", b"data", payload_size)

    return riff_section + fmt_section + data_section + audio_data
| # --- End of helper functions --- | |
def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[bytes, str]:
    """
    Synthesizes speech for ``text`` with the Gemini API through the
    module-level ``genai`` import (google.generativeai).

    Raw audio returned by the API is wrapped in a WAV container and the
    result is then compressed to MP3. If compression fails, the
    uncompressed result is returned instead.

    Args:
        text: The text to speak.
        gemini_voice_name: Name of the prebuilt Gemini voice to use.

    Returns:
        A tuple (audio_bytes, mime_type): normally MP3 data with
        "audio/mpeg", otherwise the pre-compression audio and its MIME type.

    Raises:
        TTSGenerationError: If GENERATE_SPEECH is disabled, the API call
            fails, or the response contains no audio data.
    """
    if not GENERATE_SPEECH:
        # This should ideally not be hit if the logic outside this function is correct,
        # but as a safeguard, we raise an error.
        raise TTSGenerationError(
            "GENERATE_SPEECH is not set. Please set it in your environment variables to generate speech."
        )

    try:
        model = genai.GenerativeModel(TTS_MODEL)
        # NOTE(review): confirm the installed google.generativeai version
        # accepts response_modalities / speech_config in generation_config.
        generation_config = {
            "response_modalities": ["AUDIO"],
            "speech_config": {
                "voice_config": {
                    "prebuilt_voice_config": {
                        "voice_name": gemini_voice_name
                    }
                }
            }
        }
        response = model.generate_content(
            contents=[text],
            generation_config=generation_config,
        )
        # The audio is expected in the first part of the first candidate;
        # a missing candidate or part raises here and is wrapped below.
        audio_part = response.candidates[0].content.parts[0]
        audio_data_bytes = audio_part.inline_data.data
        final_mime_type = audio_part.inline_data.mime_type
    except Exception as e:
        # Boundary with the third-party client: convert any failure into the
        # single exception type that callers handle.
        error_message = f"An unexpected error occurred with google-genai: {e}"
        logging.error(error_message)
        raise TTSGenerationError(error_message) from e

    if not audio_data_bytes:
        error_message = "No audio data was successfully retrieved or decoded."
        logging.error(error_message)
        raise TTSGenerationError(error_message)

    # audio_data_bytes is non-empty here, so the processed audio is too;
    # no further emptiness check is needed before compression.
    processed_audio_data, processed_audio_mime = _wrap_raw_audio_as_wav(
        audio_data_bytes, final_mime_type
    )
    return _compress_to_mp3(processed_audio_data, processed_audio_mime)


def _wrap_raw_audio_as_wav(audio_bytes: bytes, mime_type: str | None) -> tuple[bytes, str]:
    """Wrap raw audio in a WAV container; pass WAV/MPEG/Ogg/Opus through unchanged."""
    if not mime_type:
        logging.warning(
            "MIME type not determined. Assuming raw audio and attempting WAV conversion (defaulting to %s).",
            DEFAULT_RAW_AUDIO_MIME,
        )
        return convert_to_wav(audio_bytes, DEFAULT_RAW_AUDIO_MIME), "audio/wav"

    lowered = mime_type.lower()
    is_raw_pcm = any(p in lowered for p in ("audio/l16", "audio/l24", "audio/l8"))
    is_known_container = lowered.startswith(
        ("audio/wav", "audio/mpeg", "audio/ogg", "audio/opus")
    )
    # Anything that is not a recognised container is treated as raw samples.
    if is_raw_pcm or not is_known_container:
        return convert_to_wav(audio_bytes, mime_type), "audio/wav"
    return audio_bytes, mime_type


def _compress_to_mp3(audio_bytes: bytes, mime_type: str) -> tuple[bytes, str]:
    """Re-encode audio as MP3, returning the input unchanged if that fails."""
    lowered = mime_type.lower()
    if lowered.startswith("audio/mpeg"):
        # Already MP3: re-encoding would only cost time and quality.
        return audio_bytes, mime_type

    # Tell pydub which container it is reading. Previously "wav" was always
    # passed, so pass-through Ogg/Opus audio could never be decoded.
    if lowered.startswith("audio/wav"):
        source_format = "wav"
    elif lowered.startswith("audio/ogg"):
        source_format = "ogg"
    else:
        source_format = None  # let the decoder detect the container itself

    try:
        audio_segment = AudioSegment.from_file(io.BytesIO(audio_bytes), format=source_format)
        mp3_buffer = io.BytesIO()
        audio_segment.export(mp3_buffer, format="mp3")
    except Exception as e:
        # Best effort: compression is an optimisation, so keep the usable
        # uncompressed audio rather than failing the whole request.
        logging.warning("MP3 compression failed: %s. Falling back to %s.", e, mime_type)
        return audio_bytes, mime_type
    return mp3_buffer.getvalue(), "audio/mpeg"
# Build the memoized function unconditionally: the read-only path below never
# calls it, but still needs its __cache_key__ helper to compute lookup keys.
_memoized_tts_func = cache.memoize()(_synthesize_gemini_tts_impl)

if not GENERATE_SPEECH:
    # Speech generation is disabled: expose a lookup-only function that
    # serves previously cached audio and never generates new audio.
    def read_only_synthesize_gemini_tts(*args, **kwargs):
        """
        Returns the cached TTS result for these arguments, if there is one.

        The underlying TTS function is never invoked. On a cache miss the
        miss is logged and (None, None) is returned.
        """
        # Compute the same key the memoized function would use for this call.
        cache_key = _memoized_tts_func.__cache_key__(*args, **kwargs)

        # A unique sentinel distinguishes "no entry" from any stored value.
        missing = object()
        cached = cache.get(cache_key, default=missing)
        if cached is missing:
            logging.info("GENERATE_SPEECH is false and no cached result found for key: %s", cache_key)
            return None, None
        return cached

    synthesize_gemini_tts = read_only_synthesize_gemini_tts
else:
    def synthesize_gemini_tts_with_error_handling(*args, **kwargs) -> tuple[bytes | None, str | None]:
        """
        Calls the memoized TTS function and converts failures into (None, None).

        A TTSGenerationError is logged and swallowed so that one failed
        segment does not stop the caller; any other exception propagates.
        """
        try:
            # Served from the cache when available, generated otherwise.
            audio_and_mime = _memoized_tts_func(*args, **kwargs)
        except TTSGenerationError as err:
            logging.error("Handled TTS Generation Error: %s. Continuing without audio for this segment.", err)
            return None, None
        return audio_and_mime

    synthesize_gemini_tts = synthesize_gemini_tts_with_error_handling