diff --git a/listen_and_play.py b/listen_and_play.py
index 3c3fc583aff2328fe00f4bb4304c88a99c5b6af7..285fe553393ad9c913bea18a76cf6ebe3435c149 100644
--- a/listen_and_play.py
+++ b/listen_and_play.py
@@ -1,88 +1,129 @@
 import socket
-import sounddevice as sd
 import numpy as np
 import threading
 from queue import Queue
-
-CHUNK = 1024
-CHANNELS = 1
-SEND_RATE = 16000
-RECV_RATE = 44100
-
-HOST = '172.16.128.13'
-PORT = 12345
-RECV_PORT = 12346
-
-send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-send_socket.connect((HOST, PORT))
-
-recv_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-recv_socket.connect((HOST, RECV_PORT))
-
-print("Recording and streaming...")
-
-stop_event = threading.Event()
-recv_queue = Queue()
-send_queue = Queue()
-
-def callback_recv(outdata, frames, time, status):
-    if not recv_queue.empty():
-        data = recv_queue.get()
-        outdata[:len(data)] = data
-        outdata[len(data):] = b'\x00' * (len(outdata) - len(data))
-    else:
-        outdata[:] = b'\x00' * len(outdata)
-
-def callback_send(indata, frames, time, status):
-    if recv_queue.empty():
-        data = bytes(indata)
-        send_queue.put(data)
-
-def send(stop_event, send_queue):
-    while not stop_event.is_set():
-        data = send_queue.get()
-        send_socket.sendall(data)
-
-def recv(stop_event, recv_queue):
-
-    def receive_full_chunk(conn, chunk_size):
-        data = b''
-        while len(data) < chunk_size:
-            packet = conn.recv(chunk_size - len(data))
-            if not packet:
-                return None # Connection has been closed
-            data += packet
-        return data
-
-    while not stop_event.is_set():
-        data = receive_full_chunk(recv_socket, CHUNK * 2)
-        if data:
-            recv_queue.put(data)
-
-try:
-    send_stream = sd.RawInputStream(samplerate=SEND_RATE, channels=CHANNELS, dtype='int16', blocksize=CHUNK, callback=callback_send)
-    recv_stream = sd.RawOutputStream(samplerate=RECV_RATE, channels=CHANNELS, dtype='int16', blocksize=CHUNK, callback=callback_recv)
-    threading.Thread(target=send_stream.start).start()
-    threading.Thread(target=recv_stream.start).start()
-
-    send_thread = threading.Thread(target=send, args=(stop_event, send_queue))
-    send_thread.start()
-    recv_thread = threading.Thread(target=recv, args=(stop_event, recv_queue))
-    recv_thread.start()
-
-    input("Press Enter to stop...")
-
-except KeyboardInterrupt:
-    print("Finished streaming.")
-
-finally:
-    stop_event.set()
-    recv_thread.join()
-    print("1")
-
-    send_thread.join()
-    print("2")
-
-    send_socket.close()
-    recv_socket.close()
-    print("Connection closed.")
\ No newline at end of file
+from dataclasses import dataclass, field
+
+@dataclass
+class ListenAndPlayArguments:
+    send_rate: int = field(
+        default=16000,
+        metadata={
+            "help": "In Hz. Default is 16000."
+        }
+    )
+    recv_rate: int = field(
+        default=44100,
+        metadata={
+            "help": "In Hz. Default is 44100."
+        }
+    )
+    list_play_chunk_size: int = field(
+        default=1024,
+        metadata={
+            "help": "The size of data chunks (in bytes). Default is 1024."
+        }
+    )
+    listen_play_host: str = field(
+        default="localhost",
+        metadata={
+            "help": "The hostname or IP address for listening and playing. Default is 'localhost'."
+        }
+    )
+    listen_play_send_port: int = field(
+        default=12345,
+        metadata={
+            "help": "The network port for sending data. Default is 12345."
+        }
+    )
+    listen_play_recv_port: int = field(
+        default=12346,
+        metadata={
+            "help": "The network port for receiving data. Default is 12346."
+        }
+    )
+
+
+def listen_and_play(
+    send_rate=16000,
+    recv_rate=44100,
+    list_play_chunk_size=1024,
+    listen_play_host="localhost",
+    listen_play_send_port=12345,
+    listen_play_recv_port=12346,
+):
+    import sounddevice as sd
+
+    send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    send_socket.connect((listen_play_host, listen_play_send_port))
+
+    recv_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    recv_socket.connect((listen_play_host, listen_play_recv_port))
+
+    print("Recording and streaming...")
+
+    stop_event = threading.Event()
+    recv_queue = Queue()
+    send_queue = Queue()
+
+    def callback_recv(outdata, frames, time, status):
+        if not recv_queue.empty():
+            data = recv_queue.get()
+            outdata[:len(data)] = data
+            outdata[len(data):] = b'\x00' * (len(outdata) - len(data))
+        else:
+            outdata[:] = b'\x00' * len(outdata)
+
+    def callback_send(indata, frames, time, status):
+        if recv_queue.empty():
+            data = bytes(indata)
+            send_queue.put(data)
+
+    def send(stop_event, send_queue):
+        while not stop_event.is_set():
+            data = send_queue.get()
+            send_socket.sendall(data)
+
+    def recv(stop_event, recv_queue):
+
+        def receive_full_chunk(conn, chunk_size):
+            data = b''
+            while len(data) < chunk_size:
+                packet = conn.recv(chunk_size - len(data))
+                if not packet:
+                    return None # Connection has been closed
+                data += packet
+            return data
+
+        while not stop_event.is_set():
+            data = receive_full_chunk(recv_socket, list_play_chunk_size * 2)
+            if data:
+                recv_queue.put(data)
+
+    try:
+        send_stream = sd.RawInputStream(samplerate=send_rate, channels=1, dtype='int16', blocksize=list_play_chunk_size, callback=callback_send)
+        recv_stream = sd.RawOutputStream(samplerate=recv_rate, channels=1, dtype='int16', blocksize=list_play_chunk_size, callback=callback_recv)
+        threading.Thread(target=send_stream.start).start()
+        threading.Thread(target=recv_stream.start).start()
+
+        send_thread = threading.Thread(target=send, args=(stop_event, send_queue))
+        send_thread.start()
+        recv_thread = threading.Thread(target=recv, args=(stop_event, recv_queue))
+        recv_thread.start()
+
+        input("Press Enter to stop...")
+
+    except KeyboardInterrupt:
+        print("Finished streaming.")
+
+    finally:
+        stop_event.set()
+        recv_thread.join()
+        print("1")
+
+        send_thread.join()
+        print("2")
+
+        send_socket.close()
+        recv_socket.close()
+        print("Connection closed.")
\ No newline at end of file
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 77434287327600332f7de7e2ebed5f65f19b8109..25b6f3008ac5e275ab170be49b4857062a3793e4 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -4,6 +4,11 @@ import threading
 from threading import Thread, Event
 from queue import Queue
 from time import perf_counter
+import sys
+import os
+from dataclasses import dataclass, field
+from copy import copy
+import multiprocessing
 
 import numpy as np
 import soundfile as sf
@@ -16,21 +21,41 @@ from transformers import (
     AutoProcessor,
     AutoTokenizer,
     pipeline,
-    TextIteratorStreamer
+    TextIteratorStreamer,
+    HfArgumentParser
+)
+from parler_tts import (
+    ParlerTTSForConditionalGeneration,
+    ParlerTTSStreamer,
 )
-from parler_tts import ParlerTTSForConditionalGeneration
-
-# Local module imports
-from utils import VADIterator, int2float, ParlerTTSStreamer
 
-# Setup logging
-logging.basicConfig(
-    level=logging.DEBUG,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+from utils import (
+    VADIterator,
+    int2float,
 )
-logger = logging.getLogger(__name__)
+
+from listen_and_play import ListenAndPlayArguments
+
 
 console = Console()
+@dataclass
+class ModuleArguments:
+    log_level: str = field(
+        default="info",
+        metadata={
+            "help": "Provide logging level. Example --log_level debug. Default is 'info'."
+        }
+    )
+    client: bool = field(
+        default=False,
+        metadata={"help": "Whether or not this module is run as client."}
+    )
+    server: bool = field(
+        default=False,
+        metadata={"help": "Whether or not this module is run as server."}
+    )
+
 
 class ThreadManager:
     def __init__(self, handlers):
@@ -49,8 +74,10 @@ class ThreadManager:
         for thread in self.threads:
             thread.join()
 
+
 pipeline_start = None
 
+
 class BaseHandler:
     def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
         self.stop_event = stop_event
@@ -90,6 +117,29 @@ class BaseHandler:
         pass
 
 
+@dataclass
+class SocketReceiverArguments:
+    recv_host: str = field(
+        default="localhost",
+        metadata={
+            "help": "The host IP address for the socket connection. Default is 'localhost'; use '0.0.0.0' to bind to all "
+                    "available interfaces on the host machine."
+        }
+    )
+    recv_port: int = field(
+        default=12345,
+        metadata={
+            "help": "The port number on which the socket server listens. Default is 12345."
+        }
+    )
+    chunk_size: int = field(
+        default=1024,
+        metadata={
+            "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
+        }
+    )
+
+
 class SocketReceiver:
     def __init__(
         self,
@@ -99,17 +149,13 @@ class SocketReceiver:
         host='0.0.0.0',
         port=12345,
         chunk_size=1024
-    ):
+    ):
         self.stop_event = stop_event
         self.queue_out = queue_out
         self.should_listen = should_listen
         self.chunk_size=chunk_size
-        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        self.socket.bind((host, port))
-        self.socket.listen(1)
-        self.conn, _ = self.socket.accept()
-        logger.debug("receiver connected")
+        self.host = host
+        self.port = port
 
     def receive_full_chunk(self, conn, chunk_size):
         data = b''
@@ -122,6 +168,13 @@ class SocketReceiver:
         return data
 
     def run(self):
+        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        self.socket.bind((self.host, self.port))
+        self.socket.listen(1)
+        self.conn, _ = self.socket.accept()
+        logger.debug("receiver connected")
+
         self.should_listen.set()
         while not self.stop_event.is_set():
             audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size)
@@ -133,8 +186,25 @@ class SocketReceiver:
             self.queue_out.put(audio_chunk)
         self.conn.close()
         logger.debug("Receiver closed")
-
+
 
+@dataclass
+class SocketSenderArguments:
+    send_host: str = field(
+        default="localhost",
+        metadata={
+            "help": "The host IP address for the socket connection. Default is 'localhost'; use '0.0.0.0' to bind to all "
+                    "available interfaces on the host machine."
+        }
+    )
+    send_port: int = field(
+        default=12346,
+        metadata={
+            "help": "The port number on which the socket server listens. Default is 12346."
+        }
+    )
+
+
 class SocketSender:
     def __init__(
         self,
@@ -145,14 +215,18 @@ class SocketSender:
     ):
         self.stop_event = stop_event
         self.queue_in = queue_in
+        self.host = host
+        self.port = port
+
+
+    def run(self):
         self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        self.socket.bind((host, port))
+        self.socket.bind((self.host, self.port))
         self.socket.listen(1)
         self.conn, _ = self.socket.accept()
         logger.debug("sender connected")
 
-    def run(self):
         while not self.stop_event.is_set():
             audio_chunk = self.queue_in.get()
             self.conn.sendall(audio_chunk)
@@ -162,6 +236,46 @@ class SocketSender:
         logger.debug("Sender closed")
 
 
+@dataclass
+class VADHandlerArguments:
+    thresh: float = field(
+        default=0.3,
+        metadata={
+            "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
+        }
+    )
+    sample_rate: int = field(
+        default=16000,
+        metadata={
+            "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
+        }
+    )
+    min_silence_ms: int = field(
+        default=1000,
+        metadata={
+            "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 1000 ms."
+        }
+    )
+    min_speech_ms: int = field(
+        default=500,
+        metadata={
+            "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
+        }
+    )
+    max_speech_ms: float = field(
+        default=float('inf'),
+        metadata={
+            "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
+        }
+    )
+    speech_pad_ms: int = field(
+        default=30,
+        metadata={
+            "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 30 ms."
+        }
+    )
+
+
 class VADHandler(BaseHandler):
     def setup(
         self,
@@ -204,17 +318,69 @@ class VADHandler(BaseHandler):
             yield array
 
 
-class WhisperSTTProcessor(BaseHandler):
+@dataclass
+class WhisperSTTHandlerArguments:
+    stt_model_name: str = field(
+        default="distil-whisper/distil-large-v3",
+        metadata={
+            "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
+        }
+    )
+    stt_device: str = field(
+        default="cuda",
+        metadata={
+            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
+        }
+    )
+    stt_torch_dtype: str = field(
+        default="float16",
+        metadata={
+            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
+        }
+    )
+    stt_gen_max_new_tokens: int = field(
+        default=128,
+        metadata={
+            "help": "The maximum number of new tokens to generate. Default is 128."
+        }
+    )
+    stt_gen_num_beams: int = field(
+        default=1,
+        metadata={
+            "help": "The number of beams for beam search. Default is 1, implying greedy decoding."
+        }
+    )
+    stt_gen_return_timestamps: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to return timestamps with transcriptions. Default is False."
+        }
+    )
+    stt_gen_task: str = field(
+        default="transcribe",
+        metadata={
+            "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
+        }
+    )
+    stt_gen_language: str = field(
+        default="en",
+        metadata={
+            "help": "The language of the speech to transcribe. Default is 'en' for English."
+        }
+    )
+
+
+class WhisperSTTHandler(BaseHandler):
     def setup(
         self,
         model_name="distil-whisper/distil-large-v3",
         device="cuda",
-        torch_dtype=torch.float16,
+        torch_dtype="float16",
         gen_kwargs={}
     ):
         self.processor = AutoProcessor.from_pretrained(model_name)
         self.device = device
-        self.torch_dtype = torch_dtype
+        self.torch_dtype = getattr(torch, torch_dtype)
         self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
             model_name,
             torch_dtype=self.torch_dtype,
@@ -239,12 +405,64 @@ class WhisperSTTProcessor(BaseHandler):
             yield pred_text
 
 
+@dataclass
+class LanguageModelHandlerArguments:
+    llm_model_name: str = field(
+        default="microsoft/Phi-3-mini-4k-instruct",
+        metadata={
+            "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
+        }
+    )
+    llm_device: str = field(
+        default="cuda",
+        metadata={
+            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
+        }
+    )
+    llm_torch_dtype: str = field(
+        default="float16",
+        metadata={
+            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
+        }
+    )
+    user_role: str = field(
+        default="user",
+        metadata={
+            "help": "Role assigned to the user in the chat context. Default is 'user'."
+        }
+    )
+    init_chat_role: str = field(
+        default="system",
+        metadata={
+            "help": "Initial role for setting up the chat context. Default is 'system'."
+        }
+    )
+    init_chat_prompt: str = field(
+        default="You are a helpful AI assistant.",
+        metadata={
+            "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
+        }
+    )
+    llm_gen_max_new_tokens: int = field(
+        default=128,
+        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 128."}
+    )
+    llm_gen_temperature: float = field(
+        default=0.0,
+        metadata={"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."}
+    )
+    llm_gen_do_sample: bool = field(
+        default=False,
+        metadata={"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."}
+    )
+
+
 class LanguageModelHandler(BaseHandler):
     def setup(
         self,
         model_name="microsoft/Phi-3-mini-4k-instruct",
         device="cuda",
-        torch_dtype=torch.float16,
+        torch_dtype="float16",
         gen_kwargs={},
         user_role="user",
         init_chat_role="system",
@@ -271,6 +489,7 @@ class LanguageModelHandler(BaseHandler):
         ]
         self.gen_kwargs = {
             "streamer": self.streamer,
+            "return_full_text": False,
             **gen_kwargs
         }
         self.user_role = user_role
@@ -297,13 +516,56 @@ class LanguageModelHandler(BaseHandler):
             yield printable_text
 
 
-class ParlerTTSProcessor(BaseHandler):
+@dataclass
+class ParlerTTSHandlerArguments:
+    tts_model_name: str = field(
+        default="ylacombe/parler-tts-mini-jenny-30H",
+        metadata={
+            "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
+        }
+    )
+    tts_device: str = field(
+        default="cuda",
+        metadata={
+            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
+        }
+    )
+    tts_torch_dtype: str = field(
+        default="float16",
+        metadata={
+            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
+        }
+    )
+    gen_kwargs: dict = field(
+        default_factory=dict,
+        metadata={
+            "help": "Additional keyword arguments to pass to the model's generate method. Use this to customize generation settings."
+        }
+    )
+    description: str = field(
+        default=(
+            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
+            "She speaks very fast."
+        ),
+        metadata={
+            "help": "Description of the speaker's voice and speaking style to guide the TTS model."
+        }
+    )
+    play_steps_s: float = field(
+        default=0.5,
+        metadata={
+            "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
+        }
+    )
+
+
+class ParlerTTSHandler(BaseHandler):
     def setup(
         self,
         should_listen,
         model_name="ylacombe/parler-tts-mini-jenny-30H",
         device="cuda",
-        torch_dtype=torch.float32,
+        torch_dtype="float16",
         gen_kwargs={},
         description=(
             "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
@@ -311,6 +573,7 @@ class ParlerTTSProcessor(BaseHandler):
         ),
         play_steps_s=0.5
     ):
+        torch_dtype = getattr(torch, torch_dtype)
         self._should_listen = should_listen
         self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -361,93 +624,145 @@ class ParlerTTSProcessor(BaseHandler):
         self._should_listen.set()
 
 
-def main():
+def prepare_args(args, prefix):
+    gen_kwargs = {}
+    for key in copy(args.__dict__):
+        if key.startswith(prefix):
+            value = args.__dict__.pop(key)
+            new_key = key[len(prefix) + 1:]  # Remove prefix and underscore
+            if new_key.startswith("gen_"):
+                gen_kwargs[new_key[4:]] = value  # Remove 'gen_' and add to dict
+            else:
+                args.__dict__[new_key] = value
+
+    args.__dict__["gen_kwargs"] = gen_kwargs
 
-    device = "cuda:1"
-
-    stt_gen_kwargs = {
-        "max_new_tokens": 128,
-        "num_beams": 1,
-        "return_timestamps": False,
-        "task": "transcribe",
-        "language": "en",
-    }
-
-    llm_gen_kwargs = {
-        "max_new_tokens": 128,
-        "return_full_text": False,
-        "temperature": 0.0,
-        "do_sample": False,
-    }
-
-    tts_gen_kwargs = {
-        "min_new_tokens": 10,
-        "temperature": 1.0,
-        "do_sample": True,
-    }
-
-    stop_event = Event()
-    should_listen = Event()
-
-    recv_audio_chunks_queue = Queue()
-    send_audio_chunks_queue = Queue()
-    spoken_prompt_queue = Queue()
-    text_prompt_queue = Queue()
-    llm_response_queue = Queue()
-
-    vad = VADHandler(
-        stop_event,
-        queue_in=recv_audio_chunks_queue,
-        queue_out=spoken_prompt_queue,
-        setup_args=(should_listen,)
-    )
-    stt = WhisperSTTProcessor(
-        stop_event,
-        queue_in=spoken_prompt_queue,
-        queue_out=text_prompt_queue,
-        setup_kwargs={
-            "device": device,
-            "gen_kwargs": stt_gen_kwargs,
-        },
-    )
-    llm = LanguageModelHandler(
-        stop_event,
-        queue_in=text_prompt_queue,
-        queue_out=llm_response_queue,
-        setup_kwargs={
-            "device": device,
-            "gen_kwargs": llm_gen_kwargs,
-        },
-    )
-    tts = ParlerTTSProcessor(
-        stop_event,
-        queue_in=llm_response_queue,
-        queue_out=send_audio_chunks_queue,
-        setup_args=(should_listen,),
-        setup_kwargs={
-            "device": device,
-            "gen_kwargs": tts_gen_kwargs
-        },
-    )
-
-    recv_handler = SocketReceiver(
-        stop_event,
-        recv_audio_chunks_queue,
-        should_listen,
-    )
-    send_handler = SocketSender(
-        stop_event,
-        send_audio_chunks_queue,
+def main():
+    parser = HfArgumentParser((
+        ModuleArguments,
+        SocketReceiverArguments,
+        SocketSenderArguments,
+        VADHandlerArguments,
+        WhisperSTTHandlerArguments,
+        LanguageModelHandlerArguments,
+        ParlerTTSHandlerArguments,
+        ListenAndPlayArguments
+    ))
+
+    # 0. Parse CLI arguments
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # Parse configurations from a JSON file if specified
+        (
+            module_kwargs,
+            socket_receiver_kwargs,
+            socket_sender_kwargs,
+            vad_handler_kwargs,
+            whisper_stt_handler_kwargs,
+            language_model_handler_kwargs,
+            parler_tts_handler_kwargs,
+            listen_and_play_kwargs,
+        ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        # Parse arguments from command line if no JSON file is provided
+        (
+            module_kwargs,
+            socket_receiver_kwargs,
+            socket_sender_kwargs,
+            vad_handler_kwargs,
+            whisper_stt_handler_kwargs,
+            language_model_handler_kwargs,
+            parler_tts_handler_kwargs,
+            listen_and_play_kwargs
+        ) = parser.parse_args_into_dataclasses()
+
+    global logger
+    logging.basicConfig(
+        level=module_kwargs.log_level.upper(),
+        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    )
+    logger = logging.getLogger(__name__)
+
+    prepare_args(whisper_stt_handler_kwargs, "stt")
+    prepare_args(language_model_handler_kwargs, "llm")
+    prepare_args(parler_tts_handler_kwargs, "tts")
+
+    if not module_kwargs.client and not module_kwargs.server:
+        # Equivalent to behaving as both client and server
+        module_kwargs.client = True
+        module_kwargs.server = True
+
+    if module_kwargs.client and module_kwargs.server and (socket_receiver_kwargs.recv_host != "localhost" or socket_sender_kwargs.send_host != "localhost"):
+        raise ValueError("When running as both client and server, recv_host and send_host must be 'localhost'.")
+
     try:
-        pipeline_manager = ThreadManager([vad, tts, llm, stt, recv_handler, send_handler])
-        pipeline_manager.start()
+        if module_kwargs.server:
+            stop_event = Event()
+            should_listen = Event()
+            recv_audio_chunks_queue = Queue()
+            send_audio_chunks_queue = Queue()
+            spoken_prompt_queue = Queue()
+            text_prompt_queue = Queue()
+            llm_response_queue = Queue()
+
+            vad = VADHandler(
+                stop_event,
+                queue_in=recv_audio_chunks_queue,
+                queue_out=spoken_prompt_queue,
+                setup_args=(should_listen,),
+                setup_kwargs=vars(vad_handler_kwargs),
+            )
+            stt = WhisperSTTHandler(
+                stop_event,
+                queue_in=spoken_prompt_queue,
+                queue_out=text_prompt_queue,
+                setup_kwargs=vars(whisper_stt_handler_kwargs),
+            )
+            llm = LanguageModelHandler(
+                stop_event,
+                queue_in=text_prompt_queue,
+                queue_out=llm_response_queue,
+                setup_kwargs=vars(language_model_handler_kwargs),
+            )
+            tts = ParlerTTSHandler(
+                stop_event,
+                queue_in=llm_response_queue,
+                queue_out=send_audio_chunks_queue,
+                setup_args=(should_listen,),
+                setup_kwargs=vars(parler_tts_handler_kwargs),
+            )
+
+            recv_handler = SocketReceiver(
+                stop_event,
+                recv_audio_chunks_queue,
+                should_listen,
+                host=socket_receiver_kwargs.recv_host,
+                port=socket_receiver_kwargs.recv_port,
+                chunk_size=socket_receiver_kwargs.chunk_size,
+            )
+
+            send_handler = SocketSender(
+                stop_event,
+                send_audio_chunks_queue,
+                host=socket_sender_kwargs.send_host,
+                port=socket_sender_kwargs.send_port,
+            )
+
+            pipeline_manager = ThreadManager([vad, tts, llm, stt, recv_handler, send_handler])
+            pipeline_manager.start()
+
+        if module_kwargs.client:
+            from listen_and_play import listen_and_play
+
+            listen_and_play_process = multiprocessing.Process(
+                target=listen_and_play,
+                kwargs=vars(listen_and_play_kwargs),
+            )
+            listen_and_play_process.start()
 
     except KeyboardInterrupt:
-        pipeline_manager.stop()
-
+        if module_kwargs.server:
+            pipeline_manager.stop()
+        if module_kwargs.client:
+            listen_and_play_process.join()
 
 if __name__ == "__main__":
     main()
diff --git a/utils.py b/utils.py
index 462eba61d81cbf35654511e76bd44576b0c05954..4a3c621a54e087a7d958b8d01a0fc2063ea94313 100644
--- a/utils.py
+++ b/utils.py
@@ -32,6 +32,7 @@ def int2float(sound):
     sound = sound.squeeze()  # depends on the use case
     return sound
 
+
 class VADIterator:
     def __init__(self,
                  model,
@@ -127,132 +128,3 @@ class VADIterator:
             self.buffer.append(x)
 
         return None
-
-
-class ParlerTTSStreamer(BaseStreamer):
-    def __init__(
-        self,
-        model: ParlerTTSForConditionalGeneration,
-        device: Optional[str] = None,
-        play_steps: Optional[int] = 10,
-        stride: Optional[int] = None,
-        timeout: Optional[float] = None,
-    ):
-        """
-        Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
-        useful for applications that benefit from accessing the generated audio in a non-blocking way (e.g. in an interactive
-        Gradio demo).
-        Parameters:
-            model (`ParlerTTSForConditionalGeneration`):
-                The Parler-TTS model used to generate the audio waveform.
-            device (`str`, *optional*):
-                The torch device on which to run the computation. If `None`, will default to the device of the model.
-            play_steps (`int`, *optional*, defaults to 10):
-                The number of generation steps with which to return the generated audio array. Using fewer steps will
-                mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
-                should be tuned to your device and latency requirements.
-            stride (`int`, *optional*):
-                The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
-                the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
-                play_steps // 6 in the audio space.
-            timeout (`int`, *optional*):
-                The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
-                in `.generate()`, when it is called in a separate thread.
-        """
-        self.decoder = model.decoder
-        self.audio_encoder = model.audio_encoder
-        self.generation_config = model.generation_config
-        self.device = device if device is not None else model.device
-
-        # variables used in the streaming process
-        self.play_steps = play_steps
-        if stride is not None:
-            self.stride = stride
-        else:
-            hop_length = math.floor(self.audio_encoder.config.sampling_rate / self.audio_encoder.config.frame_rate)
-            self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
-        self.token_cache = None
-        self.to_yield = 0
-
-        # varibles used in the thread process
-        self.audio_queue = Queue()
-        self.stop_signal = None
-        self.timeout = timeout
-
-    def apply_delay_pattern_mask(self, input_ids):
-        # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler)
-        _, delay_pattern_mask = self.decoder.build_delay_pattern_mask(
-            input_ids[:, :1],
-            bos_token_id=self.generation_config.bos_token_id,
-            pad_token_id=self.generation_config.decoder_start_token_id,
-            max_length=input_ids.shape[-1],
-        )
-        # apply the pattern mask to the input ids
-        input_ids = self.decoder.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
-
-        # revert the pattern delay mask by filtering the pad token id
-        mask = (delay_pattern_mask != self.generation_config.bos_token_id) & (delay_pattern_mask != self.generation_config.pad_token_id)
-        input_ids = input_ids[mask].reshape(1, self.decoder.num_codebooks, -1)
-        # append the frame dimension back to the audio codes
-        input_ids = input_ids[None, ...]
-
-        # send the input_ids to the correct device
-        input_ids = input_ids.to(self.audio_encoder.device)
-
-        decode_sequentially = (
-            self.generation_config.bos_token_id in input_ids
-            or self.generation_config.pad_token_id in input_ids
-            or self.generation_config.eos_token_id in input_ids
-        )
-        if not decode_sequentially:
-            output_values = self.audio_encoder.decode(
-                input_ids,
-                audio_scales=[None],
-            )
-        else:
-            sample = input_ids[:, 0]
-            sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
-            sample = sample[:, :, sample_mask]
-            output_values = self.audio_encoder.decode(sample[None, ...], [None])
-
-        audio_values = output_values.audio_values[0, 0]
-        return audio_values.cpu().float().numpy()
-
-    def put(self, value):
-        batch_size = value.shape[0] // self.decoder.num_codebooks
-        if batch_size > 1:
-            raise ValueError("ParlerTTSStreamer only supports batch size 1")
-
-        if self.token_cache is None:
-            self.token_cache = value
-        else:
-            self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
-
-        if self.token_cache.shape[-1] % self.play_steps == 0:
-            audio_values = self.apply_delay_pattern_mask(self.token_cache)
-            self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
-            self.to_yield += len(audio_values) - self.to_yield - self.stride
-
-    def end(self):
-        """Flushes any remaining cache and appends the stop symbol."""
-        if self.token_cache is not None:
-            audio_values = self.apply_delay_pattern_mask(self.token_cache)
-        else:
-            audio_values = np.zeros(self.to_yield)
-
-        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
-
-    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
-        """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
-        self.audio_queue.put(audio, timeout=self.timeout)
-        if stream_end:
-            self.audio_queue.put(self.stop_signal, timeout=self.timeout)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        value = self.audio_queue.get(timeout=self.timeout)
-        if not isinstance(value, np.ndarray) and value == self.stop_signal:
-            raise StopIteration()
-        else:
-            return value
\ No newline at end of file
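
Two usage sketches follow for review purposes; neither is part of the diff itself.

First, the prefix routing that prepare_args() applies to each handler dataclass: an "stt_gen_*" field is routed into gen_kwargs with both prefixes stripped, while any other "stt_*" field only loses the "stt_" prefix. The DemoSTTArguments stand-in below is hypothetical; prepare_args itself is copied from the s2s_pipeline.py hunk above.

    from dataclasses import dataclass
    from copy import copy

    @dataclass
    class DemoSTTArguments:
        # Hypothetical stand-in for WhisperSTTHandlerArguments, trimmed to two fields.
        stt_device: str = "cuda"
        stt_gen_max_new_tokens: int = 128

    def prepare_args(args, prefix):
        # Copied from the diff: pop prefixed keys, strip the prefix, and collect
        # "gen_*" keys into a gen_kwargs dict that the handler passes to generate().
        gen_kwargs = {}
        for key in copy(args.__dict__):
            if key.startswith(prefix):
                value = args.__dict__.pop(key)
                new_key = key[len(prefix) + 1:]  # remove prefix and underscore
                if new_key.startswith("gen_"):
                    gen_kwargs[new_key[4:]] = value  # remove 'gen_' and add to dict
                else:
                    args.__dict__[new_key] = value
        args.__dict__["gen_kwargs"] = gen_kwargs

    args = DemoSTTArguments()
    prepare_args(args, "stt")
    assert vars(args) == {"device": "cuda", "gen_kwargs": {"max_new_tokens": 128}}

Second, since main() forwards a lone ".json" argument to HfArgumentParser.parse_json_file, the flags can also live in a flat config file keyed by the field names defined above; unspecified fields keep their dataclass defaults. The values here are illustrative only:

    {
        "log_level": "debug",
        "recv_host": "localhost",
        "recv_port": 12345,
        "send_host": "localhost",
        "send_port": 12346,
        "stt_model_name": "distil-whisper/distil-large-v3",
        "llm_gen_max_new_tokens": 128,
        "tts_device": "cuda",
        "listen_play_host": "localhost"
    }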