Spaces:
Build error
Build error
| import os | |
| import wave | |
| import numpy as np | |
| import scipy | |
| import ffmpeg | |
| import pyaudio | |
| import threading | |
| import textwrap | |
| import json | |
| import websocket | |
| import uuid | |
| import time | |
| def resample(file: str, sr: int = 16000): | |
| """ | |
| # https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/audio.py#L22 | |
| Open an audio file and read as mono waveform, resampling as necessary, | |
| save the resampled audio | |
| Args: | |
| file (str): The audio file to open | |
| sr (int): The sample rate to resample the audio if necessary | |
| Returns: | |
| resampled_file (str): The resampled audio file | |
| """ | |
| try: | |
| # This launches a subprocess to decode audio while down-mixing and resampling as necessary. | |
| # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. | |
| out, _ = ( | |
| ffmpeg.input(file, threads=0) | |
| .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) | |
| .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) | |
| ) | |
| except ffmpeg.Error as e: | |
| raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e | |
| np_buffer = np.frombuffer(out, dtype=np.int16) | |
| resampled_file = f"{file.split('.')[0]}_resampled.wav" | |
| scipy.io.wavfile.write(resampled_file, sr, np_buffer.astype(np.int16)) | |
| return resampled_file | |
| class Client: | |
| """ | |
| Handles audio recording, streaming, and communication with a server using WebSocket. | |
| """ | |
| INSTANCES = {} | |
| def __init__( | |
| self, host=None, port=None, is_multilingual=False, lang=None, translate=False, model_size="small" | |
| ): | |
| """ | |
| Initializes a Client instance for audio recording and streaming to a server. | |
| If host and port are not provided, the WebSocket connection will not be established. | |
| When translate is True, the task will be set to "translate" instead of "transcribe". | |
| he audio recording starts immediately upon initialization. | |
| Args: | |
| host (str): The hostname or IP address of the server. | |
| port (int): The port number for the WebSocket server. | |
| is_multilingual (bool, optional): Specifies if multilingual transcription is enabled. Default is False. | |
| lang (str, optional): The selected language for transcription when multilingual is disabled. Default is None. | |
| translate (bool, optional): Specifies if the task is translation. Default is False. | |
| """ | |
| self.chunk = 1024 * 3 | |
| self.format = pyaudio.paInt16 | |
| self.channels = 1 | |
| self.rate = 16000 | |
| self.record_seconds = 60000 | |
| self.recording = False | |
| self.multilingual = False | |
| self.language = None | |
| self.task = "transcribe" | |
| self.uid = str(uuid.uuid4()) | |
| self.waiting = False | |
| self.last_response_recieved = None | |
| self.disconnect_if_no_response_for = 15 | |
| self.multilingual = is_multilingual | |
| self.language = lang | |
| self.model_size = model_size | |
| self.server_error = False | |
| if translate: | |
| self.task = "translate" | |
| self.timestamp_offset = 0.0 | |
| self.audio_bytes = None | |
| self.p = pyaudio.PyAudio() | |
| self.stream = self.p.open( | |
| format=self.format, | |
| channels=self.channels, | |
| rate=self.rate, | |
| input=True, | |
| frames_per_buffer=self.chunk, | |
| ) | |
| if host is not None and port is not None: | |
| socket_url = f"ws://{host}:{port}" | |
| self.client_socket = websocket.WebSocketApp( | |
| socket_url, | |
| on_open=lambda ws: self.on_open(ws), | |
| on_message=lambda ws, message: self.on_message(ws, message), | |
| on_error=lambda ws, error: self.on_error(ws, error), | |
| on_close=lambda ws, close_status_code, close_msg: self.on_close( | |
| ws, close_status_code, close_msg | |
| ), | |
| ) | |
| else: | |
| print("[ERROR]: No host or port specified.") | |
| return | |
| Client.INSTANCES[self.uid] = self | |
| # start websocket client in a thread | |
| self.ws_thread = threading.Thread(target=self.client_socket.run_forever) | |
| self.ws_thread.setDaemon(True) | |
| self.ws_thread.start() | |
| self.frames = b"" | |
| print("[INFO]: * recording") | |
| # TTS audio websocket client | |
| socket_url = f"ws://{host}:8888" | |
| self.tts_client_socket = websocket.WebSocketApp( | |
| socket_url, | |
| on_open=lambda ws: self.on_open_tts(ws), | |
| on_message=lambda ws, message: self.on_message_tts(ws, message), | |
| on_error=lambda ws, error: self.on_error_tts(ws, error), | |
| on_close=lambda ws, close_status_code, close_msg: self.on_close_tts( | |
| ws, close_status_code, close_msg | |
| ), | |
| ) | |
| self.tts_ws_thread = threading.Thread(target=self.tts_client_socket.run_forever) | |
| self.tts_ws_thread.setDaemon(True) | |
| self.tts_ws_thread.start() | |
| def on_message(self, ws, message): | |
| """ | |
| Callback function called when a message is received from the server. | |
| It updates various attributes of the client based on the received message, including | |
| recording status, language detection, and server messages. If a disconnect message | |
| is received, it sets the recording status to False. | |
| Args: | |
| ws (websocket.WebSocketApp): The WebSocket client instance. | |
| message (str): The received message from the server. | |
| """ | |
| self.last_response_recieved = time.time() | |
| message = json.loads(message) | |
| if self.uid != message.get("uid"): | |
| print("[ERROR]: invalid client uid") | |
| return | |
| if "status" in message.keys(): | |
| if message["status"] == "WAIT": | |
| self.waiting = True | |
| print( | |
| f"[INFO]:Server is full. Estimated wait time {round(message['message'])} minutes." | |
| ) | |
| elif message["status"] == "ERROR": | |
| print(f"Message from Server: {message['message']}") | |
| self.server_error = True | |
| return | |
| if "message" in message.keys() and message["message"] == "DISCONNECT": | |
| print("[INFO]: Server overtime disconnected.") | |
| self.recording = False | |
| if "message" in message.keys() and message["message"] == "SERVER_READY": | |
| self.recording = True | |
| return | |
| if "language" in message.keys(): | |
| self.language = message.get("language") | |
| lang_prob = message.get("language_prob") | |
| print( | |
| f"[INFO]: Server detected language {self.language} with probability {lang_prob}" | |
| ) | |
| return | |
| if "llm_output" in message.keys(): | |
| print("LLM output: ") | |
| for item in message["llm_output"]: | |
| print(item) | |
| if "segments" not in message.keys(): | |
| return | |
| message = message["segments"] | |
| text = [] | |
| print(message) | |
| if len(message): | |
| for seg in message: | |
| if text and text[-1] == seg["text"]: | |
| # already got it | |
| continue | |
| text.append(seg["text"]) | |
| # keep only last 3 | |
| if len(text) > 3: | |
| text = text[-3:] | |
| wrapper = textwrap.TextWrapper(width=60) | |
| word_list = wrapper.wrap(text="".join(text)) | |
| # Print each line. | |
| # if os.name == "nt": | |
| # os.system("cls") | |
| # else: | |
| # os.system("clear") | |
| for element in word_list: | |
| print(element) | |
| def on_error(self, ws, error): | |
| print(error) | |
| def on_close(self, ws, close_status_code, close_msg): | |
| print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}") | |
| def on_open(self, ws): | |
| """ | |
| Callback function called when the WebSocket connection is successfully opened. | |
| Sends an initial configuration message to the server, including client UID, multilingual mode, | |
| language selection, and task type. | |
| Args: | |
| ws (websocket.WebSocketApp): The WebSocket client instance. | |
| """ | |
| print(self.multilingual, self.language, self.task) | |
| print("[INFO]: Opened connection") | |
| ws.send( | |
| json.dumps( | |
| { | |
| "uid": self.uid, | |
| "multilingual": self.multilingual, | |
| "language": self.language, | |
| "task": self.task, | |
| "model_size": self.model_size, | |
| } | |
| ) | |
| ) | |
| def on_open_tts(self): | |
| pass | |
| def on_message_tts(self, ws, message): | |
| # print(message) | |
| print(type(message)) | |
| self.write_audio_frames_to_file(message.tobytes(), "tts_out.wav", rate=24000) | |
| pass | |
| def on_error_tts(self, ws, error): | |
| print(error) | |
| def on_close_tts(self, ws, close_status_code, close_msg): | |
| print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}") | |
| def bytes_to_float_array(audio_bytes): | |
| """ | |
| Convert audio data from bytes to a NumPy float array. | |
| It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to | |
| have values between -1 and 1. | |
| Args: | |
| audio_bytes (bytes): Audio data in bytes. | |
| Returns: | |
| np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1. | |
| """ | |
| raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16) | |
| return raw_data.astype(np.float32) / 32768.0 | |
| def send_packet_to_server(self, message): | |
| """ | |
| Send an audio packet to the server using WebSocket. | |
| Args: | |
| message (bytes): The audio data packet in bytes to be sent to the server. | |
| """ | |
| try: | |
| self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY) | |
| except Exception as e: | |
| print(e) | |
| def play_file(self, filename): | |
| """ | |
| Play an audio file and send it to the server for processing. | |
| Reads an audio file, plays it through the audio output, and simultaneously sends | |
| the audio data to the server for processing. It uses PyAudio to create an audio | |
| stream for playback. The audio data is read from the file in chunks, converted to | |
| floating-point format, and sent to the server using WebSocket communication. | |
| This method is typically used when you want to process pre-recorded audio and send it | |
| to the server in real-time. | |
| Args: | |
| filename (str): The path to the audio file to be played and sent to the server. | |
| """ | |
| # read audio and create pyaudio stream | |
| with wave.open(filename, "rb") as wavfile: | |
| self.stream = self.p.open( | |
| format=self.p.get_format_from_width(wavfile.getsampwidth()), | |
| channels=wavfile.getnchannels(), | |
| rate=wavfile.getframerate(), | |
| input=True, | |
| output=True, | |
| frames_per_buffer=self.chunk, | |
| ) | |
| try: | |
| while self.recording: | |
| data = wavfile.readframes(self.chunk) | |
| if data == b"": | |
| break | |
| audio_array = self.bytes_to_float_array(data) | |
| self.send_packet_to_server(audio_array.tobytes()) | |
| self.stream.write(data) | |
| wavfile.close() | |
| assert self.last_response_recieved | |
| while time.time() - self.last_response_recieved < self.disconnect_if_no_response_for: | |
| continue | |
| self.stream.close() | |
| self.close_websocket() | |
| except KeyboardInterrupt: | |
| wavfile.close() | |
| self.stream.stop_stream() | |
| self.stream.close() | |
| self.p.terminate() | |
| self.close_websocket() | |
| print("[INFO]: Keyboard interrupt.") | |
| def close_websocket(self): | |
| """ | |
| Close the WebSocket connection and join the WebSocket thread. | |
| First attempts to close the WebSocket connection using `self.client_socket.close()`. After | |
| closing the connection, it joins the WebSocket thread to ensure proper termination. | |
| """ | |
| try: | |
| self.client_socket.close() | |
| except Exception as e: | |
| print("[ERROR]: Error closing WebSocket:", e) | |
| try: | |
| self.ws_thread.join() | |
| except Exception as e: | |
| print("[ERROR:] Error joining WebSocket thread:", e) | |
| def get_client_socket(self): | |
| """ | |
| Get the WebSocket client socket instance. | |
| Returns: | |
| WebSocketApp: The WebSocket client socket instance currently in use by the client. | |
| """ | |
| return self.client_socket | |
| def write_audio_frames_to_file(self, frames, file_name, rate=None): | |
| """ | |
| Write audio frames to a WAV file. | |
| The WAV file is created or overwritten with the specified name. The audio frames should be | |
| in the correct format and match the specified channel, sample width, and sample rate. | |
| Args: | |
| frames (bytes): The audio frames to be written to the file. | |
| file_name (str): The name of the WAV file to which the frames will be written. | |
| """ | |
| with wave.open(file_name, "wb") as wavfile: | |
| wavfile: wave.Wave_write | |
| wavfile.setnchannels(self.channels) | |
| wavfile.setsampwidth(2) | |
| wavfile.setframerate(self.rate if rate is None else rate) | |
| wavfile.writeframes(frames) | |
| def process_hls_stream(self, hls_url): | |
| """ | |
| Connect to an HLS source, process the audio stream, and send it for transcription. | |
| Args: | |
| hls_url (str): The URL of the HLS stream source. | |
| """ | |
| print("[INFO]: Connecting to HLS stream...") | |
| process = None # Initialize process to None | |
| try: | |
| # Connecting to the HLS stream using ffmpeg-python | |
| process = ( | |
| ffmpeg | |
| .input(hls_url, threads=0) | |
| .output('-', format='s16le', acodec='pcm_s16le', ac=1, ar=self.rate) | |
| .run_async(pipe_stdout=True, pipe_stderr=True) | |
| ) | |
| # Process the stream | |
| while True: | |
| in_bytes = process.stdout.read(self.chunk * 2) # 2 bytes per sample | |
| if not in_bytes: | |
| break | |
| audio_array = self.bytes_to_float_array(in_bytes) | |
| self.send_packet_to_server(audio_array.tobytes()) | |
| except Exception as e: | |
| print(f"[ERROR]: Failed to connect to HLS stream: {e}") | |
| finally: | |
| if process: | |
| process.kill() | |
| print("[INFO]: HLS stream processing finished.") | |
| def record(self, out_file="output_recording.wav"): | |
| """ | |
| Record audio data from the input stream and save it to a WAV file. | |
| Continuously records audio data from the input stream, sends it to the server via a WebSocket | |
| connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when | |
| the `RECORD_SECONDS` duration is reached or when the `RECORDING` flag is set to `False`. | |
| Audio data is saved in chunks to the "chunks" directory. Each chunk is saved as a separate WAV file. | |
| The recording will continue until the specified duration is reached or until the `RECORDING` flag is set to `False`. | |
| The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording, | |
| the method combines all the saved audio chunks into the specified `out_file`. | |
| Args: | |
| out_file (str, optional): The name of the output WAV file to save the entire recording. Default is "output_recording.wav". | |
| """ | |
| n_audio_file = 0 | |
| if not os.path.exists("chunks"): | |
| os.makedirs("chunks", exist_ok=True) | |
| try: | |
| for _ in range(0, int(self.rate / self.chunk * self.record_seconds)): | |
| if not self.recording: | |
| break | |
| data = self.stream.read(self.chunk) | |
| self.frames += data | |
| audio_array = Client.bytes_to_float_array(data) | |
| self.send_packet_to_server(audio_array.tobytes()) | |
| # save frames if more than a minute | |
| if len(self.frames) > 60 * self.rate: | |
| t = threading.Thread( | |
| target=self.write_audio_frames_to_file, | |
| args=( | |
| self.frames[:], | |
| f"chunks/{n_audio_file}.wav", | |
| ), | |
| ) | |
| t.start() | |
| n_audio_file += 1 | |
| self.frames = b"" | |
| except KeyboardInterrupt: | |
| if len(self.frames): | |
| self.write_audio_frames_to_file( | |
| self.frames[:], f"chunks/{n_audio_file}.wav" | |
| ) | |
| n_audio_file += 1 | |
| self.stream.stop_stream() | |
| self.stream.close() | |
| self.p.terminate() | |
| self.close_websocket() | |
| self.write_output_recording(n_audio_file, out_file) | |
| def write_output_recording(self, n_audio_file, out_file): | |
| """ | |
| Combine and save recorded audio chunks into a single WAV file. | |
| The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk | |
| file, appends its audio data to the final recording, and then deletes the chunk file. After combining | |
| and saving, the final recording is stored in the specified `out_file`. | |
| Args: | |
| n_audio_file (int): The number of audio chunk files to combine. | |
| out_file (str): The name of the output WAV file to save the final recording. | |
| """ | |
| input_files = [ | |
| f"chunks/{i}.wav" | |
| for i in range(n_audio_file) | |
| if os.path.exists(f"chunks/{i}.wav") | |
| ] | |
| with wave.open(out_file, "wb") as wavfile: | |
| wavfile: wave.Wave_write | |
| wavfile.setnchannels(self.channels) | |
| wavfile.setsampwidth(2) | |
| wavfile.setframerate(self.rate) | |
| for in_file in input_files: | |
| with wave.open(in_file, "rb") as wav_in: | |
| while True: | |
| data = wav_in.readframes(self.chunk) | |
| if data == b"": | |
| break | |
| wavfile.writeframes(data) | |
| # remove this file | |
| os.remove(in_file) | |
| wavfile.close() | |
| class TranscriptionClient: | |
| """ | |
| Client for handling audio transcription tasks via a WebSocket connection. | |
| Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used | |
| to send audio data for transcription to a server and receive transcribed text segments. | |
| Args: | |
| host (str): The hostname or IP address of the server. | |
| port (int): The port number to connect to on the server. | |
| is_multilingual (bool, optional): Indicates whether the transcription should support multiple languages (default is False). | |
| lang (str, optional): The primary language for transcription (used if `is_multilingual` is False). Default is None, which defaults to English ('en'). | |
| translate (bool, optional): Indicates whether translation tasks are required (default is False). | |
| Attributes: | |
| client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection. | |
| Example: | |
| To create a TranscriptionClient and start transcription on microphone audio: | |
| ```python | |
| transcription_client = TranscriptionClient(host="localhost", port=9090, is_multilingual=True) | |
| transcription_client() | |
| ``` | |
| """ | |
| def __init__(self, host, port, is_multilingual=False, lang=None, translate=False, model_size="small"): | |
| self.client = Client(host, port, is_multilingual, lang, translate, model_size) | |
| def __call__(self, audio=None, hls_url=None): | |
| """ | |
| Start the transcription process. | |
| Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server | |
| to be ready to receive audio data and then sends audio for transcription. If an audio file is provided, it | |
| will be played and streamed to the server; otherwise, it will perform live recording. | |
| Args: | |
| audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording. | |
| """ | |
| print("[INFO]: Waiting for server ready ...") | |
| while not self.client.recording: | |
| if self.client.waiting or self.client.server_error: | |
| self.client.close_websocket() | |
| return | |
| print("[INFO]: Server Ready!") | |
| if hls_url is not None: | |
| self.client.process_hls_stream(hls_url) | |
| elif audio is not None: | |
| resampled_file = resample(audio) | |
| self.client.play_file(resampled_file) | |
| else: | |
| self.client.record() | |