diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..acf6c1af2a5d2b8936c2574506ed33cdeee6dc4b
--- /dev/null
+++ b/STT/lightning_whisper_mlx_handler.py
@@ -0,0 +1,56 @@
+import logging
+from time import perf_counter
+from baseHandler import BaseHandler
+from lightning_whisper_mlx import LightningWhisperMLX
+import numpy as np
+from rich.console import Console
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class LightningWhisperSTTHandler(BaseHandler):
+    """
+    Handles Speech To Text generation using a Lightning Whisper MLX model.
+    """
+
+    def setup(
+        self,
+        model_name="distil-whisper/distil-large-v3",
+        device="cuda",
+        torch_dtype="float16",
+        compile_mode=None,
+        gen_kwargs={},
+    ):
+        self.device = device
+        self.model = LightningWhisperMLX(
+            model="distil-medium.en", batch_size=12, quant=None
+        )
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+
+        # a single warmup step is enough since the MLX model is not torch-compiled
+        n_steps = 1
+        dummy_input = np.array([0] * 512)
+
+        for _ in range(n_steps):
+            _ = self.model.transcribe(dummy_input)["text"].strip()
+
+    def process(self, spoken_prompt):
+        logger.debug("inferring whisper...")
+
+        global pipeline_start
+        pipeline_start = perf_counter()
+
+        pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
+
+        logger.debug("finished whisper inference")
+        console.print(f"[yellow]USER: {pred_text}")
+
+        yield pred_text
diff --git a/TTS/melotts.py b/TTS/melotts.py
new file mode 100644
index 0000000000000000000000000000000000000000..eef660105611c99581bbb07e32184068f38cb54e
--- /dev/null
+++ b/TTS/melotts.py
@@ -0,0 +1,53 @@
+from MeloTTS.melo.api import TTS
+import logging
+from baseHandler import BaseHandler
+import librosa
+import numpy as np
+from rich.console import Console
+import torch
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class MeloTTSHandler(BaseHandler):
+    def setup(
+        self,
+        should_listen,
+        device="mps",
+        language="EN_NEWEST",
+        blocksize=512,
+    ):
+        self.should_listen = should_listen
+        self.device = device
+        self.model = TTS(language=language, device=device)
+        self.speaker_id = self.model.hps.data.spk2id["EN-Newest"]
+        self.blocksize = blocksize
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+        _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
+
+    def process(self, llm_sentence):
+        console.print(f"[green]ASSISTANT: {llm_sentence}")
+        if self.device == "mps":
+            torch.mps.synchronize()  # Waits for all kernels in all streams on the MPS device to complete.
+
+        audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
+        if len(audio_chunk) == 0:
+            self.should_listen.set()
+            return
+        audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
+        audio_chunk = (audio_chunk * 32768).astype(np.int16)
+        for i in range(0, len(audio_chunk), self.blocksize):
+            yield np.pad(
+                audio_chunk[i : i + self.blocksize],
+                (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
+            )
+
+        self.should_listen.set()
diff --git a/baseHandler.py b/baseHandler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f5efa80adb473cff0b50336ddbc744fe50ba27c
--- /dev/null
+++ b/baseHandler.py
@@ -0,0 +1,51 @@
+from time import perf_counter
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BaseHandler:
+    """
+    Base class for pipeline parts. Each part of the pipeline has an input and an output queue.
+    The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part.
+    To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue.
+    Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue.
+    The cleanup method handles stopping the handler, and b"END" is placed in the output queue.
+    """
+
+    def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
+        self.stop_event = stop_event
+        self.queue_in = queue_in
+        self.queue_out = queue_out
+        self.setup(*setup_args, **setup_kwargs)
+        self._times = []
+
+    def setup(self):
+        pass
+
+    def process(self):
+        raise NotImplementedError
+
+    def run(self):
+        while not self.stop_event.is_set():
+            input = self.queue_in.get()
+            if isinstance(input, bytes) and input == b"END":
+                # sentinel signal to avoid queue deadlock
+                logger.debug("Stopping thread")
+                break
+            start_time = perf_counter()
+            for output in self.process(input):
+                self._times.append(perf_counter() - start_time)
+                logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
+                self.queue_out.put(output)
+                start_time = perf_counter()
+
+        self.cleanup()
+        self.queue_out.put(b"END")
+
+    @property
+    def last_time(self):
+        return self._times[-1]
+
+    def cleanup(self):
+        pass
diff --git a/listen_and_play.py b/listen_and_play.py
index 675f78b84a39c41f41f05b51b141fde87f105dfc..8fc0dfec09eeab4bbb465078031bd4ec9a9a82d5 100644
--- a/listen_and_play.py
+++ b/listen_and_play.py
@@ -1,5 +1,4 @@
 import socket
-import numpy as np
 import threading
 from queue import Queue
 from dataclasses import dataclass, field
@@ -9,41 +8,25 @@ from transformers import HfArgumentParser
 @dataclass
 class ListenAndPlayArguments:
-    send_rate: int = field(
-        default=16000,
-        metadata={
-            "help": "In Hz. Default is 16000."
-        }
-    )
-    recv_rate: int = field(
-        default=44100,
-        metadata={
-            "help": "In Hz. Default is 44100."
-        }
-    )
+    send_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
+    recv_rate: int = field(default=44100, metadata={"help": "In Hz. Default is 44100."})
     list_play_chunk_size: int = field(
         default=1024,
-        metadata={
-            "help": "The size of data chunks (in bytes). Default is 1024."
-        }
+        metadata={"help": "The size of data chunks (in bytes). Default is 1024."},
     )
     host: str = field(
         default="localhost",
         metadata={
             "help": "The hostname or IP address for listening and playing. Default is 'localhost'."
- } + }, ) send_port: int = field( default=12345, - metadata={ - "help": "The network port for sending data. Default is 12345." - } + metadata={"help": "The network port for sending data. Default is 12345."}, ) recv_port: int = field( default=12346, - metadata={ - "help": "The network port for receiving data. Default is 12346." - } + metadata={"help": "The network port for receiving data. Default is 12346."}, ) @@ -55,7 +38,6 @@ def listen_and_play( send_port=12345, recv_port=12346, ): - send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) send_socket.connect((host, send_port)) @@ -68,13 +50,13 @@ def listen_and_play( recv_queue = Queue() send_queue = Queue() - def callback_recv(outdata, frames, time, status): + def callback_recv(outdata, frames, time, status): if not recv_queue.empty(): data = recv_queue.get() - outdata[:len(data)] = data - outdata[len(data):] = b'\x00' * (len(outdata) - len(data)) + outdata[: len(data)] = data + outdata[len(data) :] = b"\x00" * (len(outdata) - len(data)) else: - outdata[:] = b'\x00' * len(outdata) + outdata[:] = b"\x00" * len(outdata) def callback_send(indata, frames, time, status): if recv_queue.empty(): @@ -85,11 +67,10 @@ def listen_and_play( while not stop_event.is_set(): data = send_queue.get() send_socket.sendall(data) - - def recv(stop_event, recv_queue): + def recv(stop_event, recv_queue): def receive_full_chunk(conn, chunk_size): - data = b'' + data = b"" while len(data) < chunk_size: packet = conn.recv(chunk_size - len(data)) if not packet: @@ -98,13 +79,25 @@ def listen_and_play( return data while not stop_event.is_set(): - data = receive_full_chunk(recv_socket, list_play_chunk_size * 2) + data = receive_full_chunk(recv_socket, list_play_chunk_size * 2) if data: recv_queue.put(data) - try: - send_stream = sd.RawInputStream(samplerate=send_rate, channels=1, dtype='int16', blocksize=list_play_chunk_size, callback=callback_send) - recv_stream = sd.RawOutputStream(samplerate=recv_rate, channels=1, dtype='int16', blocksize=list_play_chunk_size, callback=callback_recv) + try: + send_stream = sd.RawInputStream( + samplerate=send_rate, + channels=1, + dtype="int16", + blocksize=list_play_chunk_size, + callback=callback_send, + ) + recv_stream = sd.RawOutputStream( + samplerate=recv_rate, + channels=1, + dtype="int16", + blocksize=list_play_chunk_size, + callback=callback_recv, + ) threading.Thread(target=send_stream.start).start() threading.Thread(target=recv_stream.start).start() @@ -112,7 +105,7 @@ def listen_and_play( send_thread.start() recv_thread = threading.Thread(target=recv, args=(stop_event, recv_queue)) recv_thread.start() - + input("Press Enter to stop...") except KeyboardInterrupt: @@ -129,6 +122,5 @@ def listen_and_play( if __name__ == "__main__": parser = HfArgumentParser((ListenAndPlayArguments,)) - listen_and_play_kwargs, = parser.parse_args_into_dataclasses() + (listen_and_play_kwargs,) = parser.parse_args_into_dataclasses() listen_and_play(**vars(listen_and_play_kwargs)) - diff --git a/local_audio_streamer.py b/local_audio_streamer.py new file mode 100644 index 0000000000000000000000000000000000000000..5203785aa6094e692c511540a679d239e03bae33 --- /dev/null +++ b/local_audio_streamer.py @@ -0,0 +1,38 @@ +import threading +import sounddevice as sd +import numpy as np + +import time + + +class LocalAudioStreamer: + def __init__( + self, + input_queue, + output_queue, + list_play_chunk_size=512, + ): + self.list_play_chunk_size = list_play_chunk_size + + self.stop_event = threading.Event() + self.input_queue = input_queue + 
self.output_queue = output_queue + + def run(self): + def callback(indata, outdata, frames, time, status): + if self.output_queue.empty(): + self.input_queue.put(indata.copy()) + outdata[:] = 0 * outdata + else: + outdata[:] = self.output_queue.get()[:, np.newaxis] + + with sd.Stream( + samplerate=16000, + dtype="int16", + channels=1, + callback=callback, + blocksize=self.list_play_chunk_size, + ): + while not self.stop_event.is_set(): + time.sleep(0.001) + print("Stopping recording") diff --git a/requirements.txt b/requirements.txt index e04ebea181f5323d7ce91c7a916c1c34c6d45ef8..b928d18a342987752c2c2f1851939562abdeadf3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ nltk==3.8.1 parler_tts @ git+https://github.com/huggingface/parler-tts.git torch==2.4.0 sounddevice==0.5.0 +lightning-whisper-mlx==0.0.10 \ No newline at end of file diff --git a/s2s_pipeline.py b/s2s_pipeline.py index b367bb1ca929fae4339f00dca971c4a56ef19835..8c36a5d3f47c351a47c98283783c18c9f110ca16 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -3,17 +3,19 @@ import os import socket import sys import threading -from collections import deque from copy import copy from dataclasses import dataclass, field from pathlib import Path from queue import Queue from threading import Event, Thread from time import perf_counter +from typing import Optional +from TTS.melotts import MeloTTSHandler +from baseHandler import BaseHandler +from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler import numpy as np import torch -import nltk from nltk.tokenize import sent_tokenize from rich.console import Console from transformers import ( @@ -23,19 +25,13 @@ from transformers import ( AutoTokenizer, HfArgumentParser, pipeline, - TextIteratorStreamer + TextIteratorStreamer, ) +from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer +import librosa -from parler_tts import ( - ParlerTTSForConditionalGeneration, - ParlerTTSStreamer -) - -from utils import ( - VADIterator, - int2float, - next_power_of_2 -) +from local_audio_streamer import LocalAudioStreamer +from utils import VADIterator, int2float, next_power_of_2 # Ensure that the necessary NLTK resources are available try: @@ -46,10 +42,10 @@ except (LookupError, OSError): # caching allows ~50% compilation time reduction # see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma CURRENT_DIR = Path(__file__).resolve().parent -os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp") -torch._inductor.config.fx_graph_cache = True -# mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS -torch._dynamo.config.cache_size_limit = 15 +os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp") +# torch._inductor.config.fx_graph_cache = True +# # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS +# torch._dynamo.config.cache_size_limit = 15 console = Console() @@ -57,11 +53,21 @@ console = Console() @dataclass class ModuleArguments: + device: Optional[str] = field( + default=None, + metadata={"help": "If specified, overrides the device for all handlers."}, + ) + mode: Optional[str] = field( + default="local", + metadata={ + "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'." + }, + ) log_level: str = field( default="info", metadata={ "help": "Provide logging level. Example --log_level debug, default=warning." 
- } + }, ) @@ -87,73 +93,26 @@ class ThreadManager: thread.join() -class BaseHandler: - """ - Base class for pipeline parts. Each part of the pipeline has an input and an output queue. - The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part. - To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue. - Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue. - The cleanup method handles stopping the handler, and b"END" is placed in the output queue. - """ - - def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}): - self.stop_event = stop_event - self.queue_in = queue_in - self.queue_out = queue_out - self.setup(*setup_args, **setup_kwargs) - self._times = [] - - def setup(self): - pass - - def process(self): - raise NotImplementedError - - def run(self): - while not self.stop_event.is_set(): - input = self.queue_in.get() - if isinstance(input, bytes) and input == b'END': - # sentinelle signal to avoid queue deadlock - logger.debug("Stopping thread") - break - start_time = perf_counter() - for output in self.process(input): - self._times.append(perf_counter() - start_time) - logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s") - self.queue_out.put(output) - start_time = perf_counter() - - self.cleanup() - self.queue_out.put(b'END') - - @property - def last_time(self): - return self._times[-1] - - def cleanup(self): - pass - - @dataclass class SocketReceiverArguments: recv_host: str = field( default="localhost", metadata={ "help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all " - "available interfaces on the host machine." - } + "available interfaces on the host machine." + }, ) recv_port: int = field( default=12345, metadata={ "help": "The port number on which the socket server listens. Default is 12346." - } + }, ) chunk_size: int = field( default=1024, metadata={ "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes." 
- } + }, ) @@ -163,28 +122,28 @@ class SocketReceiver: """ def __init__( - self, + self, stop_event, queue_out, should_listen, - host='0.0.0.0', + host="0.0.0.0", port=12345, - chunk_size=1024 - ): + chunk_size=1024, + ): self.stop_event = stop_event self.queue_out = queue_out self.should_listen = should_listen - self.chunk_size=chunk_size + self.chunk_size = chunk_size self.host = host self.port = port def receive_full_chunk(self, conn, chunk_size): - data = b'' + data = b"" while len(data) < chunk_size: packet = conn.recv(chunk_size - len(data)) if not packet: # connection closed - return None + return None data += packet return data @@ -193,7 +152,7 @@ class SocketReceiver: self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.bind((self.host, self.port)) self.socket.listen(1) - logger.info('Receiver waiting to be connected...') + logger.info("Receiver waiting to be connected...") self.conn, _ = self.socket.accept() logger.info("receiver connected") @@ -202,7 +161,7 @@ class SocketReceiver: audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size) if audio_chunk is None: # connection closed - self.queue_out.put(b'END') + self.queue_out.put(b"END") break if self.should_listen.is_set(): self.queue_out.put(audio_chunk) @@ -216,48 +175,41 @@ class SocketSenderArguments: default="localhost", metadata={ "help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all " - "available interfaces on the host machine." - } + "available interfaces on the host machine." + }, ) send_port: int = field( default=12346, metadata={ "help": "The port number on which the socket server listens. Default is 12346." - } + }, ) - + class SocketSender: """ Handles sending generated audio packets to the clients. """ - def __init__( - self, - stop_event, - queue_in, - host='0.0.0.0', - port=12346 - ): + def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346): self.stop_event = stop_event self.queue_in = queue_in self.host = host self.port = port - def run(self): self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.bind((self.host, self.port)) self.socket.listen(1) - logger.info('Sender waiting to be connected...') + logger.info("Sender waiting to be connected...") self.conn, _ = self.socket.accept() logger.info("sender connected") while not self.stop_event.is_set(): audio_chunk = self.queue_in.get() self.conn.sendall(audio_chunk) - if isinstance(audio_chunk, bytes) and audio_chunk == b'END': + if isinstance(audio_chunk, bytes) and audio_chunk == b"END": break self.conn.close() logger.info("Sender closed") @@ -269,37 +221,37 @@ class VADHandlerArguments: default=0.3, metadata={ "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection." - } + }, ) sample_rate: int = field( default=16000, metadata={ "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio." - } + }, ) min_silence_ms: int = field( default=250, metadata={ - "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 1000 ms." - } + "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms." 
+ }, ) min_speech_ms: int = field( - default=750, + default=500, metadata={ "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms." - } + }, ) max_speech_ms: float = field( - default=float('inf'), + default=float("inf"), metadata={ "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments." - } + }, ) speech_pad_ms: int = field( - default=30, + default=250, metadata={ - "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 30 ms." - } + "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms." + }, ) @@ -310,22 +262,21 @@ class VADHandler(BaseHandler): """ def setup( - self, - should_listen, - thresh=0.3, - sample_rate=16000, - min_silence_ms=1000, - min_speech_ms=500, - max_speech_ms=float('inf'), - speech_pad_ms=30, - - ): + self, + should_listen, + thresh=0.3, + sample_rate=16000, + min_silence_ms=1000, + min_speech_ms=500, + max_speech_ms=float("inf"), + speech_pad_ms=30, + ): self.should_listen = should_listen self.sample_rate = sample_rate self.min_silence_ms = min_silence_ms self.min_speech_ms = min_speech_ms self.max_speech_ms = max_speech_ms - self.model, _ = torch.hub.load('snakers4/silero-vad', 'silero_vad') + self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad") self.iterator = VADIterator( self.model, threshold=thresh, @@ -343,7 +294,9 @@ class VADHandler(BaseHandler): array = torch.cat(vad_output).cpu().numpy() duration_ms = len(array) / self.sample_rate * 1000 if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms: - logger.debug(f"audio input of duration: {len(array) / self.sample_rate}s, skipping") + logger.debug( + f"audio input of duration: {len(array) / self.sample_rate}s, skipping" + ) else: self.should_listen.clear() logger.debug("Stop listening") @@ -356,56 +309,56 @@ class WhisperSTTHandlerArguments: default="distil-whisper/distil-large-v3", metadata={ "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'." - } + }, ) stt_device: str = field( default="cuda", metadata={ "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." - } + }, ) stt_torch_dtype: str = field( default="float16", metadata={ "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." - } + }, ) stt_compile_mode: str = field( default=None, metadata={ "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)" - } + }, ) stt_gen_max_new_tokens: int = field( default=128, metadata={ "help": "The maximum number of new tokens to generate. Default is 128." - } + }, ) stt_gen_num_beams: int = field( default=1, metadata={ "help": "The number of beams for beam search. Default is 1, implying greedy decoding." - } + }, ) stt_gen_return_timestamps: bool = field( default=False, metadata={ "help": "Whether to return timestamps with transcriptions. Default is False." - } - ) - stt_gen_task: str = field( - default="transcribe", - metadata={ - "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'." - } - ) - stt_gen_language: str = field( - default="en", - metadata={ - "help": "The language of the speech to transcribe. 
Default is 'en' for English." - } + }, ) + # stt_gen_task: str = field( + # default="transcribe", + # metadata={ + # "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'." + # }, + # ) + # stt_gen_language: str = field( + # default="en", + # metadata={ + # "help": "The language of the speech to transcribe. Default is 'en' for English." + # }, + # ) class WhisperSTTHandler(BaseHandler): @@ -414,16 +367,16 @@ class WhisperSTTHandler(BaseHandler): """ def setup( - self, - model_name="distil-whisper/distil-large-v3", - device="cuda", - torch_dtype="float16", - compile_mode=None, - gen_kwargs={} - ): + self, + model_name="distil-whisper/distil-large-v3", + device="cuda", + torch_dtype="float16", + compile_mode=None, + gen_kwargs={}, + ): self.device = device self.torch_dtype = getattr(torch, torch_dtype) - self.compile_mode=compile_mode + self.compile_mode = compile_mode self.gen_kwargs = gen_kwargs self.processor = AutoProcessor.from_pretrained(model_name) @@ -431,13 +384,15 @@ class WhisperSTTHandler(BaseHandler): model_name, torch_dtype=self.torch_dtype, ).to(device) - + # compile if self.compile_mode: self.model.generation_config.cache_implementation = "static" - self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True) + self.model.forward = torch.compile( + self.model.forward, mode=self.compile_mode, fullgraph=True + ) self.warmup() - + def prepare_model_inputs(self, spoken_prompt): input_features = self.processor( spoken_prompt, sampling_rate=16000, return_tensors="pt" @@ -445,39 +400,44 @@ class WhisperSTTHandler(BaseHandler): input_features = input_features.to(self.device, dtype=self.torch_dtype) return input_features - + def warmup(self): logger.info(f"Warming up {self.__class__.__name__}") - # 2 warmup steps for no compile or compile mode with CUDA graphs capture + # 2 warmup steps for no compile or compile mode with CUDA graphs capture n_steps = 1 if self.compile_mode == "default" else 2 dummy_input = torch.randn( - (1, self.model.config.num_mel_bins, 3000), + (1, self.model.config.num_mel_bins, 3000), dtype=self.torch_dtype, - device=self.device - ) + device=self.device, + ) if self.compile_mode not in (None, "default"): # generating more tokens than previously will trigger CUDA graphs capture # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation warmup_gen_kwargs = { "min_new_tokens": self.gen_kwargs["max_new_tokens"], "max_new_tokens": self.gen_kwargs["max_new_tokens"], - **self.gen_kwargs + **self.gen_kwargs, } else: warmup_gen_kwargs = self.gen_kwargs - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if self.device == "cuda": + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() - torch.cuda.synchronize() - start_event.record() for _ in range(n_steps): _ = self.model.generate(dummy_input, **warmup_gen_kwargs) - end_event.record() - torch.cuda.synchronize() - logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s") + if self.device == "cuda": + end_event.record() + torch.cuda.synchronize() + + logger.info( + f"{self.__class__.__name__}: warmed up! 
time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" + ) def process(self, spoken_prompt): logger.debug("infering whisper...") @@ -488,9 +448,7 @@ class WhisperSTTHandler(BaseHandler): input_features = self.prepare_model_inputs(spoken_prompt) pred_ids = self.model.generate(input_features, **self.gen_kwargs) pred_text = self.processor.batch_decode( - pred_ids, - skip_special_tokens=True, - decode_with_timestamps=False + pred_ids, skip_special_tokens=True, decode_with_timestamps=False )[0] logger.debug("finished whisper inference") @@ -502,56 +460,64 @@ class WhisperSTTHandler(BaseHandler): @dataclass class LanguageModelHandlerArguments: lm_model_name: str = field( - default="microsoft/Phi-3-mini-4k-instruct", + default="HuggingFaceTB/SmolLM-360M-Instruct", metadata={ "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'." - } + }, ) lm_device: str = field( default="cuda", metadata={ "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." - } + }, ) lm_torch_dtype: str = field( default="float16", metadata={ "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." - } + }, ) user_role: str = field( default="user", metadata={ "help": "Role assigned to the user in the chat context. Default is 'user'." - } + }, ) init_chat_role: str = field( default=None, metadata={ "help": "Initial role for setting up the chat context. Default is 'system'." - } + }, ) init_chat_prompt: str = field( default="You are a helpful AI assistant.", metadata={ "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'" - } + }, ) lm_gen_max_new_tokens: int = field( default=64, - metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 128."} + metadata={ + "help": "Maximum number of new tokens to generate in a single completion. Default is 128." + }, ) lm_gen_temperature: float = field( default=0.0, - metadata={"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."} + metadata={ + "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0." + }, ) lm_gen_do_sample: bool = field( default=False, - metadata={"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."} + metadata={ + "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False." + }, ) chat_size: int = field( default=1, - metadata={"help": "Number of interactions assitant-user to keep for the chat. None for no limitations."} + metadata={ + "help": "Number of interactions assitant-user to keep for the chat. None for no limitations." + }, ) @@ -570,7 +536,7 @@ class Chat: self.buffer.append(item) if len(self.buffer) == 2 * (self.size + 1): self.buffer.pop(0) - self.buffer.pop(0) + self.buffer.pop(0) def init_chat(self, init_chat_message): self.init_chat_message = init_chat_message @@ -584,34 +550,30 @@ class Chat: class LanguageModelHandler(BaseHandler): """ - Handles the language model part. + Handles the language model part. 
""" def setup( - self, - model_name="microsoft/Phi-3-mini-4k-instruct", - device="cuda", - torch_dtype="float16", - gen_kwargs={}, - user_role="user", - chat_size=1, - init_chat_role=None, - init_chat_prompt="You are a helpful AI assistant.", - ): + self, + model_name="microsoft/Phi-3-mini-4k-instruct", + device="cuda", + torch_dtype="float16", + gen_kwargs={}, + user_role="user", + chat_size=1, + init_chat_role=None, + init_chat_prompt="You are a helpful AI assistant.", + ): self.device = device self.torch_dtype = getattr(torch, torch_dtype) self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True + model_name, torch_dtype=torch_dtype, trust_remote_code=True ).to(device) - self.pipe = pipeline( - "text-generation", - model=self.model, - tokenizer=self.tokenizer, - ) + self.pipe = pipeline( + "text-generation", model=self.model, tokenizer=self.tokenizer, device=device + ) self.streamer = TextIteratorStreamer( self.tokenizer, skip_prompt=True, @@ -620,16 +582,16 @@ class LanguageModelHandler(BaseHandler): self.gen_kwargs = { "streamer": self.streamer, "return_full_text": False, - **gen_kwargs + **gen_kwargs, } self.chat = Chat(chat_size) if init_chat_role: if not init_chat_prompt: - raise ValueError(f"An initial promt needs to be specified when setting init_chat_role.") - self.chat.init_chat( - {"role": init_chat_role, "content": init_chat_prompt} - ) + raise ValueError( + "An initial promt needs to be specified when setting init_chat_role." + ) + self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt}) self.user_role = user_role self.warmup() @@ -642,47 +604,57 @@ class LanguageModelHandler(BaseHandler): warmup_gen_kwargs = { "min_new_tokens": self.gen_kwargs["max_new_tokens"], "max_new_tokens": self.gen_kwargs["max_new_tokens"], - **self.gen_kwargs + **self.gen_kwargs, } n_steps = 2 - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if self.device == "cuda": + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() - torch.cuda.synchronize() - start_event.record() for _ in range(n_steps): - thread = Thread(target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs) + thread = Thread( + target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs + ) thread.start() - for _ in self.streamer: - pass - end_event.record() - torch.cuda.synchronize() + for _ in self.streamer: + pass - logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s") + if self.device == "cuda": + end_event.record() + torch.cuda.synchronize() + + logger.info( + f"{self.__class__.__name__}: warmed up! 
time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" + ) def process(self, prompt): logger.debug("infering language model...") - self.chat.append( - {"role": self.user_role, "content": prompt} + self.chat.append({"role": self.user_role, "content": prompt}) + thread = Thread( + target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs ) - thread = Thread(target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs) thread.start() + if self.device == "mps": + generated_text = "" + for new_text in self.streamer: + generated_text += new_text + printable_text = generated_text + else: + generated_text, printable_text = "", "" + for new_text in self.streamer: + generated_text += new_text + printable_text += new_text + sentences = sent_tokenize(printable_text) + if len(sentences) > 1: + yield (sentences[0]) + printable_text = new_text - generated_text, printable_text = "", "" - for new_text in self.streamer: - generated_text += new_text - printable_text += new_text - sentences = sent_tokenize(printable_text) - if len(sentences) > 1: - yield(sentences[0]) - printable_text = new_text - - self.chat.append( - {"role": "assistant", "content": generated_text} - ) + self.chat.append({"role": "assistant", "content": generated_text}) # don't forget last sentence yield printable_text @@ -694,33 +666,37 @@ class ParlerTTSHandlerArguments: default="ylacombe/parler-tts-mini-jenny-30H", metadata={ "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'." - } + }, ) tts_device: str = field( default="cuda", metadata={ "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." - } + }, ) tts_torch_dtype: str = field( default="float16", metadata={ "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." - } + }, ) tts_compile_mode: str = field( default=None, metadata={ "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)" - } + }, ) tts_gen_min_new_tokens: int = field( - default=None, - metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"} + default=64, + metadata={ + "help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs" + }, ) tts_gen_max_new_tokens: int = field( default=512, - metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"} + metadata={ + "help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs" + }, ) description: str = field( default=( @@ -729,38 +705,39 @@ class ParlerTTSHandlerArguments: ), metadata={ "help": "Description of the speaker's voice and speaking style to guide the TTS model." - } + }, ) play_steps_s: float = field( - default=0.2, + default=1.0, metadata={ "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds." - } + }, ) max_prompt_pad_length: int = field( default=8, metadata={ "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible." 
- } - ) + }, + ) class ParlerTTSHandler(BaseHandler): def setup( - self, - should_listen, - model_name="ylacombe/parler-tts-mini-jenny-30H", - device="cuda", - torch_dtype="float16", - compile_mode=None, - gen_kwargs={}, - max_prompt_pad_length=8, - description=( - "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " - "She speaks very fast." - ), - play_steps_s=1 - ): + self, + should_listen, + model_name="ylacombe/parler-tts-mini-jenny-30H", + device="cuda", + torch_dtype="float16", + compile_mode=None, + gen_kwargs={}, + max_prompt_pad_length=8, + description=( + "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " + "She speaks very fast." + ), + play_steps_s=1, + blocksize=512, + ): self.should_listen = should_listen self.device = device self.torch_dtype = getattr(torch, torch_dtype) @@ -769,23 +746,27 @@ class ParlerTTSHandler(BaseHandler): self.max_prompt_pad_length = max_prompt_pad_length self.description = description - self.description_tokenizer = AutoTokenizer.from_pretrained(model_name) + self.description_tokenizer = AutoTokenizer.from_pretrained(model_name) self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = ParlerTTSForConditionalGeneration.from_pretrained( - model_name, - torch_dtype=self.torch_dtype + model_name, torch_dtype=self.torch_dtype ).to(device) - + framerate = self.model.audio_encoder.config.frame_rate self.play_steps = int(framerate * play_steps_s) + self.blocksize = blocksize if self.compile_mode not in (None, "default"): - logger.warning("Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'") + logger.warning( + "Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. 
Reverting to 'default'" + ) self.compile_mode = "default" if self.compile_mode: self.model.generation_config.cache_implementation = "static" - self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True) + self.model.forward = torch.compile( + self.model.forward, mode=self.compile_mode, fullgraph=True + ) self.warmup() @@ -795,13 +776,19 @@ class ParlerTTSHandler(BaseHandler): max_length_prompt=50, pad=False, ): - pad_args_prompt = {"padding": "max_length", "max_length": max_length_prompt} if pad else {} + pad_args_prompt = ( + {"padding": "max_length", "max_length": max_length_prompt} if pad else {} + ) - tokenized_description = self.description_tokenizer(self.description, return_tensors="pt") + tokenized_description = self.description_tokenizer( + self.description, return_tensors="pt" + ) input_ids = tokenized_description.input_ids.to(self.device) attention_mask = tokenized_description.attention_mask.to(self.device) - tokenized_prompt = self.prompt_tokenizer(prompt, return_tensors="pt", **pad_args_prompt) + tokenized_prompt = self.prompt_tokenizer( + prompt, return_tensors="pt", **pad_args_prompt + ) prompt_input_ids = tokenized_prompt.input_ids.to(self.device) prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device) @@ -810,29 +797,29 @@ class ParlerTTSHandler(BaseHandler): "attention_mask": attention_mask, "prompt_input_ids": prompt_input_ids, "prompt_attention_mask": prompt_attention_mask, - **self.gen_kwargs + **self.gen_kwargs, } return gen_kwargs - + def warmup(self): logger.info(f"Warming up {self.__class__.__name__}") - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if self.device == "cuda": + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) - # 2 warmup steps for no compile or compile mode with CUDA graphs capture + # 2 warmup steps for no compile or compile mode with CUDA graphs capture n_steps = 1 if self.compile_mode == "default" else 2 - torch.cuda.synchronize() - start_event.record() + if self.device == "cuda": + torch.cuda.synchronize() + start_event.record() if self.compile_mode: pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)] for pad_length in pad_lengths[::-1]: model_kwargs = self.prepare_model_inputs( - "dummy prompt", - max_length_prompt=pad_length, - pad=True + "dummy prompt", max_length_prompt=pad_length, pad=True ) for _ in range(n_steps): _ = self.model.generate(**model_kwargs) @@ -840,12 +827,14 @@ class ParlerTTSHandler(BaseHandler): else: model_kwargs = self.prepare_model_inputs("dummy prompt") for _ in range(n_steps): - _ = self.model.generate(**model_kwargs) - - end_event.record() - torch.cuda.synchronize() - logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s") + _ = self.model.generate(**model_kwargs) + if self.device == "cuda": + end_event.record() + torch.cuda.synchronize() + logger.info( + f"{self.__class__.__name__}: warmed up! 
time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" + ) def process(self, llm_sentence): console.print(f"[green]ASSISTANT: {llm_sentence}") @@ -858,26 +847,32 @@ class ParlerTTSHandler(BaseHandler): logger.debug(f"padding to {pad_length}") pad_args["pad"] = True pad_args["max_length_prompt"] = pad_length - + tts_gen_kwargs = self.prepare_model_inputs( llm_sentence, **pad_args, ) - streamer = ParlerTTSStreamer(self.model, device=self.device, play_steps=self.play_steps) - tts_gen_kwargs = { - "streamer": streamer, - **tts_gen_kwargs - } + streamer = ParlerTTSStreamer( + self.model, device=self.device, play_steps=self.play_steps + ) + tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs} torch.manual_seed(0) thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs) thread.start() for i, audio_chunk in enumerate(streamer): if i == 0: - logger.info(f"Time to first audio: {perf_counter() - pipeline_start:.3f}") - audio_chunk = np.int16(audio_chunk * 32767) - yield audio_chunk + logger.info( + f"Time to first audio: {perf_counter() - pipeline_start:.3f}" + ) + audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) + audio_chunk = (audio_chunk * 32768).astype(np.int16) + for i in range(0, len(audio_chunk), self.blocksize): + yield np.pad( + audio_chunk[i : i + self.blocksize], + (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), + ) self.should_listen.set() @@ -891,7 +886,7 @@ def prepare_args(args, prefix): for key in copy(args.__dict__): if key.startswith(prefix): value = args.__dict__.pop(key) - new_key = key[len(prefix) + 1:] # Remove prefix and underscore + new_key = key[len(prefix) + 1 :] # Remove prefix and underscore if new_key.startswith("gen_"): gen_kwargs[new_key[4:]] = value # Remove 'gen_' and add to dict else: @@ -901,37 +896,39 @@ def prepare_args(args, prefix): def main(): - parser = HfArgumentParser(( - ModuleArguments, - SocketReceiverArguments, - SocketSenderArguments, - VADHandlerArguments, - WhisperSTTHandlerArguments, - LanguageModelHandlerArguments, - ParlerTTSHandlerArguments, - )) + parser = HfArgumentParser( + ( + ModuleArguments, + SocketReceiverArguments, + SocketSenderArguments, + VADHandlerArguments, + WhisperSTTHandlerArguments, + LanguageModelHandlerArguments, + ParlerTTSHandlerArguments, + ) + ) # 0. 
Parse CLI arguments if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Parse configurations from a JSON file if specified ( module_kwargs, - socket_receiver_kwargs, - socket_sender_kwargs, - vad_handler_kwargs, - whisper_stt_handler_kwargs, - language_model_handler_kwargs, + socket_receiver_kwargs, + socket_sender_kwargs, + vad_handler_kwargs, + whisper_stt_handler_kwargs, + language_model_handler_kwargs, parler_tts_handler_kwargs, ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: # Parse arguments from command line if no JSON file is provided ( module_kwargs, - socket_receiver_kwargs, - socket_sender_kwargs, - vad_handler_kwargs, - whisper_stt_handler_kwargs, - language_model_handler_kwargs, + socket_receiver_kwargs, + socket_sender_kwargs, + vad_handler_kwargs, + whisper_stt_handler_kwargs, + language_model_handler_kwargs, parler_tts_handler_kwargs, ) = parser.parse_args_into_dataclasses() @@ -939,7 +936,7 @@ def main(): global logger logging.basicConfig( level=module_kwargs.log_level.upper(), - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) @@ -948,20 +945,62 @@ def main(): torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True) # 2. Prepare each part's arguments + def overwrite_device_argument(common_device: Optional[str], *handler_kwargs): + if common_device: + for kwargs in handler_kwargs: + if hasattr(kwargs, "lm_device"): + kwargs.lm_device = common_device + if hasattr(kwargs, "tts_device"): + kwargs.tts_device = common_device + if hasattr(kwargs, "stt_device"): + kwargs.stt_device = common_device + + # Call this function with the common device and all the handlers + overwrite_device_argument( + module_kwargs.device, + language_model_handler_kwargs, + parler_tts_handler_kwargs, + whisper_stt_handler_kwargs, + ) + prepare_args(whisper_stt_handler_kwargs, "stt") prepare_args(language_model_handler_kwargs, "lm") - prepare_args(parler_tts_handler_kwargs, "tts") + prepare_args(parler_tts_handler_kwargs, "tts") # 3. 
Build the pipeline stop_event = Event() # used to stop putting received audio chunks in queue until all setences have been processed by the TTS - should_listen = Event() + should_listen = Event() recv_audio_chunks_queue = Queue() send_audio_chunks_queue = Queue() - spoken_prompt_queue = Queue() + spoken_prompt_queue = Queue() text_prompt_queue = Queue() lm_response_queue = Queue() - + + if module_kwargs.mode == "local": + local_audio_streamer = LocalAudioStreamer( + input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue + ) + comms_handlers = [local_audio_streamer] + should_listen.set() + else: + comms_handlers = [ + SocketReceiver( + stop_event, + recv_audio_chunks_queue, + should_listen, + host=socket_receiver_kwargs.recv_host, + port=socket_receiver_kwargs.recv_port, + chunk_size=socket_receiver_kwargs.chunk_size, + ), + SocketSender( + stop_event, + send_audio_chunks_queue, + host=socket_sender_kwargs.send_host, + port=socket_sender_kwargs.send_port, + ), + ] + vad = VADHandler( stop_event, queue_in=recv_audio_chunks_queue, @@ -969,7 +1008,7 @@ def main(): setup_args=(should_listen,), setup_kwargs=vars(vad_handler_kwargs), ) - stt = WhisperSTTHandler( + stt = LightningWhisperSTTHandler( stop_event, queue_in=spoken_prompt_queue, queue_out=text_prompt_queue, @@ -981,37 +1020,28 @@ def main(): queue_out=lm_response_queue, setup_kwargs=vars(language_model_handler_kwargs), ) - tts = ParlerTTSHandler( + # tts = ParlerTTSHandler( + # stop_event, + # queue_in=lm_response_queue, + # queue_out=send_audio_chunks_queue, + # setup_args=(should_listen,), + # setup_kwargs=vars(parler_tts_handler_kwargs), + # ) + tts = MeloTTSHandler( stop_event, queue_in=lm_response_queue, queue_out=send_audio_chunks_queue, setup_args=(should_listen,), - setup_kwargs=vars(parler_tts_handler_kwargs), - ) - - recv_handler = SocketReceiver( - stop_event, - recv_audio_chunks_queue, - should_listen, - host=socket_receiver_kwargs.recv_host, - port=socket_receiver_kwargs.recv_port, - chunk_size=socket_receiver_kwargs.chunk_size, ) - send_handler = SocketSender( - stop_event, - send_audio_chunks_queue, - host=socket_sender_kwargs.send_host, - port=socket_sender_kwargs.send_port, - ) - # 4. 
Run the pipeline try: - pipeline_manager = ThreadManager([vad, tts, lm, stt, recv_handler, send_handler]) + pipeline_manager = ThreadManager([*comms_handlers, vad, stt, lm, tts]) pipeline_manager.start() except KeyboardInterrupt: pipeline_manager.stop() - + + if __name__ == "__main__": main() diff --git a/utils.py b/utils.py index 6a353c7f136f0af6cc8e22c53b97642a0ad47e19..f4237a1121e5a398e09bb8249bd37a9bffa81b43 100644 --- a/utils.py +++ b/utils.py @@ -2,8 +2,8 @@ import numpy as np import torch -def next_power_of_2(x): - return 1 if x == 0 else 2**(x - 1).bit_length() +def next_power_of_2(x): + return 1 if x == 0 else 2 ** (x - 1).bit_length() def int2float(sound): @@ -12,22 +12,22 @@ def int2float(sound): """ abs_max = np.abs(sound).max() - sound = sound.astype('float32') + sound = sound.astype("float32") if abs_max > 0: - sound *= 1/32768 + sound *= 1 / 32768 sound = sound.squeeze() # depends on the use case return sound class VADIterator: - def __init__(self, - model, - threshold: float = 0.5, - sampling_rate: int = 16000, - min_silence_duration_ms: int = 100, - speech_pad_ms: int = 30 - ): - + def __init__( + self, + model, + threshold: float = 0.5, + sampling_rate: int = 16000, + min_silence_duration_ms: int = 100, + speech_pad_ms: int = 30, + ): """ Mainly taken from https://github.com/snakers4/silero-vad Class for stream imitation @@ -57,14 +57,15 @@ class VADIterator: self.buffer = [] if sampling_rate not in [8000, 16000]: - raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]') + raise ValueError( + "VADIterator does not support sampling rates other than [8000, 16000]" + ) self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000 self.reset_states() def reset_states(self): - self.model.reset_states() self.triggered = False self.temp_end = 0 @@ -107,11 +108,11 @@ class VADIterator: # end of speak self.temp_end = 0 self.triggered = False - spoken_utterance = self.buffer + spoken_utterance = self.buffer self.buffer = [] return spoken_utterance - + if self.triggered: self.buffer.append(x) - + return None
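
Usage note (illustrative, not part of the patch): the changes above move every pipeline stage onto a common queue-and-thread architecture, where each stage subclasses BaseHandler, reads items from an input Queue, yields results to an output Queue, and shuts down on a b"END" sentinel. The sketch below shows how a custom stage could be wired into that protocol. It is a minimal sketch: the UppercaseHandler class, its prefix argument, and the two-queue wiring are hypothetical illustrations and do not appear anywhere in this diff.

# Hypothetical example: a trivial BaseHandler subclass driven by the same
# stop_event / queue / b"END" protocol used by the handlers added in this patch.
from queue import Queue
from threading import Event, Thread

from baseHandler import BaseHandler  # added by this diff


class UppercaseHandler(BaseHandler):
    def setup(self, prefix=""):
        # setup_kwargs passed to __init__ end up here (illustrative parameter)
        self.prefix = prefix

    def process(self, text):
        # whatever is yielded here is placed on queue_out by BaseHandler.run()
        yield self.prefix + text.upper()


if __name__ == "__main__":
    stop_event = Event()
    queue_in, queue_out = Queue(), Queue()
    handler = UppercaseHandler(
        stop_event, queue_in, queue_out, setup_kwargs={"prefix": "ASSISTANT: "}
    )

    thread = Thread(target=handler.run)
    thread.start()

    queue_in.put("hello world")
    print(queue_out.get())  # ASSISTANT: HELLO WORLD

    queue_in.put(b"END")  # sentinel: lets run() exit without blocking on get()
    thread.join()
    print(queue_out.get())  # b"END" is forwarded to the next stage

For the pipeline itself, the ModuleArguments added to s2s_pipeline.py suggest that the new local mode on a Mac would be started with something like python s2s_pipeline.py --mode local --device mps, assuming HfArgumentParser derives the flag names from the device and mode fields in the usual way.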