diff --git a/s2s_pipeline.py b/s2s_pipeline.py index 1b744ff85d3bd9e768c16cce5b236a949025aefb..f7af261c0afde3bd187eaaec48e4680f0badc8d0 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -1,38 +1,39 @@ import logging +import os import socket +import sys import threading -from threading import Thread, Event +from collections import deque +from copy import copy +from dataclasses import dataclass, field +from pathlib import Path from queue import Queue +from threading import Event, Thread from time import perf_counter -import sys -import os -from pathlib import Path -from dataclasses import dataclass, field -from copy import copy -from collections import deque import numpy as np import torch from nltk.tokenize import sent_tokenize from rich.console import Console from transformers import ( - AutoModelForCausalLM, - AutoModelForSpeechSeq2Seq, - AutoProcessor, - AutoTokenizer, - pipeline, - TextIteratorStreamer, - HfArgumentParser + AutoModelForCausalLM, + AutoModelForSpeechSeq2Seq, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + pipeline, + TextIteratorStreamer ) + from parler_tts import ( ParlerTTSForConditionalGeneration, - ParlerTTSStreamer, + ParlerTTSStreamer ) from utils import ( - VADIterator, + VADIterator, int2float, - next_power_of_2, + next_power_of_2 ) @@ -44,8 +45,10 @@ torch._inductor.config.fx_graph_cache = True # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS torch._dynamo.config.cache_size_limit = 15 + console = Console() + @dataclass class ModuleArguments: log_level: str = field( @@ -55,7 +58,12 @@ class ModuleArguments: } ) + class ThreadManager: + """ + Manages multiple threads used to execute given handler tasks. + """ + def __init__(self, handlers): self.handlers = handlers self.threads = [] @@ -72,7 +80,16 @@ class ThreadManager: for thread in self.threads: thread.join() + class BaseHandler: + """ + Base class for pipeline parts. Each part of the pipeline has an input and an output queue. 
+ The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part. + To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue. + Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue. + The cleanup method handles stopping the handler, and b"END" is placed in the output queue. + """ + def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}): self.stop_event = stop_event self.queue_in = queue_in @@ -135,6 +152,10 @@ class SocketReceiverArguments: class SocketReceiver: + """ + Handles reception of the audio packets from the client. + """ + def __init__( self, stop_event, @@ -201,6 +222,10 @@ class SocketSenderArguments: class SocketSender: + """ + Handles sending generated audio packets to the clients. + """ + def __init__( self, stop_event, @@ -273,6 +298,11 @@ class VADHandlerArguments: class VADHandler(BaseHandler): + """ + Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed + to the following part. 
+ """ + def setup( self, should_listen, @@ -284,11 +314,11 @@ class VADHandler(BaseHandler): speech_pad_ms=30, ): - self._should_listen = should_listen - self._sample_rate = sample_rate - self._min_silence_ms = min_silence_ms - self._min_speech_ms = min_speech_ms - self._max_speech_ms = max_speech_ms + self.should_listen = should_listen + self.sample_rate = sample_rate + self.min_silence_ms = min_silence_ms + self.min_speech_ms = min_speech_ms + self.max_speech_ms = max_speech_ms self.model, _ = torch.hub.load('snakers4/silero-vad', 'silero_vad') self.iterator = VADIterator( self.model, @@ -305,8 +335,8 @@ class VADHandler(BaseHandler): if vad_output is not None and len(vad_output) != 0: logger.debug("VAD: end of speech detected") array = torch.cat(vad_output).cpu().numpy() - duration_ms = len(array) / self._sample_rate * 1000 - if duration_ms < self._min_speech_ms or duration_ms > self._max_speech_ms: + duration_ms = len(array) / self.sample_rate * 1000 + if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms: logger.debug(f"audio input of duration: {len(array) / self._sample_rate}s, skipping") else: self._should_listen.clear() @@ -373,6 +403,10 @@ class WhisperSTTHandlerArguments: class WhisperSTTHandler(BaseHandler): + """ + Handles the Speech To Text generation using a Whisper model. 
+ """ + def setup( self, model_name="distil-whisper/distil-large-v3", @@ -381,16 +415,17 @@ class WhisperSTTHandler(BaseHandler): compile_mode=None, gen_kwargs={} ): - self.compile_mode=compile_mode - self.processor = AutoProcessor.from_pretrained(model_name) self.device = device self.torch_dtype = getattr(torch, torch_dtype) + self.compile_mode=compile_mode + self.gen_kwargs = gen_kwargs + + self.processor = AutoProcessor.from_pretrained(model_name) self.model = AutoModelForSpeechSeq2Seq.from_pretrained( model_name, torch_dtype=self.torch_dtype, ).to(device) - self.gen_kwargs = gen_kwargs - + # compile if self.compile_mode: self.model.generation_config.cache_implementation = "static" @@ -402,20 +437,19 @@ class WhisperSTTHandler(BaseHandler): spoken_prompt, sampling_rate=16000, return_tensors="pt" ).input_features input_features = input_features.to(self.device, dtype=self.torch_dtype) + return input_features def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + # 2 warmup steps for no compile or compile mode with CUDA graphs capture n_steps = 1 if self.compile_mode == "default" else 2 - logger.info(f"Warming up {self.__class__.__name__}") dummy_input = torch.randn( (1, self.model.config.num_mel_bins, 3000), dtype=self.torch_dtype, device=self.device ) - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() if self.compile_mode not in (None, "default"): # generating more tokens than previously will trigger CUDA graphs capture # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation @@ -427,28 +461,35 @@ class WhisperSTTHandler(BaseHandler): else: warmup_gen_kwargs = self.gen_kwargs + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + torch.cuda.synchronize() start_event.record() for _ in range(n_steps): _ = self.model.generate(dummy_input, **warmup_gen_kwargs) 
end_event.record() torch.cuda.synchronize() + logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s") def process(self, spoken_prompt): + logger.debug("infering whisper...") + global pipeline_start pipeline_start = perf_counter() - input_features = self.processor( - spoken_prompt, sampling_rate=16000, return_tensors="pt" - ).input_features - input_features = input_features.to(self.device, dtype=self.torch_dtype) - logger.debug("infering whisper...") + + input_features = self.prepare_model_inputs(spoken_prompt) pred_ids = self.model.generate(input_features, **self.gen_kwargs) pred_text = self.processor.batch_decode( - pred_ids, skip_special_tokens=True, + pred_ids, + skip_special_tokens=True, decode_with_timestamps=False )[0] + logger.debug("finished whisper inference") console.print(f"[yellow]USER: {pred_text}") + yield pred_text @@ -509,6 +550,10 @@ class LanguageModelHandlerArguments: class Chat: + """ + Handles the chat using a circular buffer to avoid OOM issues. + """ + def __init__(self, size): self.init_chat_message = None self.buffer = deque(maxlen=size) @@ -527,25 +572,30 @@ class Chat: class LanguageModelHandler(BaseHandler): + """ + Handles the language model part. 
+ """ + def setup( self, model_name="microsoft/Phi-3-mini-4k-instruct", device="cuda", torch_dtype="float16", - chat_size=3, gen_kwargs={}, user_role="user", + chat_size=3, init_chat_role=None, init_chat_prompt="You are a helpful AI assistant.", ): + self.device = device self.torch_dtype = getattr(torch, torch_dtype) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch_dtype, trust_remote_code=True ).to(device) - self.device = device self.pipe = pipeline( "text-generation", model=self.model, @@ -556,6 +606,12 @@ class LanguageModelHandler(BaseHandler): skip_prompt=True, skip_special_tokens=True, ) + self.gen_kwargs = { + "streamer": self.streamer, + "return_full_text": False, + **gen_kwargs + } + self.chat = Chat(chat_size) if init_chat_role: if not init_chat_prompt: @@ -563,26 +619,12 @@ class LanguageModelHandler(BaseHandler): self.chat.init_chat( {"role": init_chat_role, "content": init_chat_prompt} ) - - self.gen_kwargs = { - "streamer": self.streamer, - "return_full_text": False, - **gen_kwargs - } self.user_role = user_role - - - self.warmup() def warmup(self): - # 2 warmup steps for no compile or compile mode with CUDA graphs capture - n_steps = 2 logger.info(f"Warming up {self.__class__.__name__}") - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() dummy_input_text = "Write me a poem about Machine Learning." 
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}] @@ -592,25 +634,33 @@ class LanguageModelHandler(BaseHandler): **self.gen_kwargs } + n_steps = 2 + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + torch.cuda.synchronize() start_event.record() for _ in range(n_steps): thread = Thread(target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs) thread.start() for _ in self.streamer: - pass - + pass end_event.record() torch.cuda.synchronize() + logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s") def process(self, prompt): + logger.debug("infering language model...") + self.chat.append( {"role": self.user_role, "content": prompt} ) thread = Thread(target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs) thread.start() + generated_text, printable_text = "", "" - logger.debug("infering language model...") for new_text in self.streamer: generated_text += new_text printable_text += new_text @@ -618,9 +668,11 @@ class LanguageModelHandler(BaseHandler): if len(sentences) > 1: yield(sentences[0]) printable_text = new_text + self.chat.append( {"role": "assistant", "content": generated_text} ) + # don't forget last sentence yield printable_text @@ -689,37 +741,33 @@ class ParlerTTSHandler(BaseHandler): model_name="ylacombe/parler-tts-mini-jenny-30H", device="cuda", torch_dtype="float16", - max_prompt_pad_length=8, - gen_kwargs={}, compile_mode=None, + gen_kwargs={}, + max_prompt_pad_length=8, description=( "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " "She speaks very fast." 
), play_steps_s=1 ): + self.should_listen = should_listen + self.device = device + self.torch_dtype = getattr(torch, torch_dtype) + self.gen_kwargs = gen_kwargs + self.compile_mode = compile_mode self.max_prompt_pad_length = max_prompt_pad_length - torch_dtype = getattr(torch, torch_dtype) - self._should_listen = should_listen + self.description = description + self.description_tokenizer = AutoTokenizer.from_pretrained(model_name) self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = ParlerTTSForConditionalGeneration.from_pretrained( model_name, - torch_dtype=torch_dtype + torch_dtype=self.torch_dtype ).to(device) - self.device = device - self.torch_dtype = torch_dtype - - self.description = description - self.gen_kwargs = gen_kwargs - - framerate = self.model.audio_encoder.config.frame_rate - self.play_steps = int(framerate * play_steps_s) framerate = self.model.audio_encoder.config.frame_rate self.play_steps = int(framerate * play_steps_s) - self.compile_mode = compile_mode if self.compile_mode not in (None, "default"): logger.warning("Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. 
Reverting to 'default'") self.compile_mode = "default" @@ -727,6 +775,7 @@ class ParlerTTSHandler(BaseHandler): if self.compile_mode: self.model.generation_config.cache_implementation = "static" self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True) + self.warmup() def prepare_model_inputs( @@ -752,29 +801,40 @@ class ParlerTTSHandler(BaseHandler): "prompt_attention_mask": prompt_attention_mask, **self.gen_kwargs } + return gen_kwargs def warmup(self): - pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)] - for pad_length in pad_lengths[::-1]: - model_kwargs = self.prepare_model_inputs( - "dummy prompt", - max_length_prompt=pad_length, - pad=True - ) - # 2 warmup steps for modes that capture CUDA graphs - n_steps = 1 if self.compile_mode == "default" else 2 - - logger.info(f"Warming up length {pad_length} tokens...") - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start_event.record() + logger.info(f"Warming up {self.__class__.__name__}") + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # 2 warmup steps for no compile or compile mode with CUDA graphs capture + n_steps = 1 if self.compile_mode == "default" else 2 + + torch.cuda.synchronize() + start_event.record() + if self.compile_mode: + pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)] + for pad_length in pad_lengths[::-1]: + model_kwargs = self.prepare_model_inputs( + "dummy prompt", + max_length_prompt=pad_length, + pad=True + ) + for _ in range(n_steps): + _ = self.model.generate(**model_kwargs) + logger.info(f"Warmed up length {pad_length} tokens!") + else: + model_kwargs = self.prepare_model_inputs("dummy prompt") for _ in range(n_steps): - _ = self.model.generate(**model_kwargs) - end_event.record() - torch.cuda.synchronize() - logger.info(f"Warmed up! 
Compilation time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s") + _ = self.model.generate(**model_kwargs) + + end_event.record() + torch.cuda.synchronize() + logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s") + def process(self, llm_sentence): console.print(f"[green]ASSISTANT: {llm_sentence}") @@ -808,10 +868,14 @@ class ParlerTTSHandler(BaseHandler): audio_chunk = np.int16(audio_chunk * 32767) yield audio_chunk - self._should_listen.set() + self.should_listen.set() def prepare_args(args, prefix): + """ + Renames arguments by removing the prefix and prepares the gen_kwargs. + """ + gen_kwargs = {} for key in copy(args.__dict__): if key.startswith(prefix): @@ -860,6 +924,7 @@ def main(): parler_tts_handler_kwargs, ) = parser.parse_args_into_dataclasses() + # 1. Handle logger global logger logging.basicConfig( level=module_kwargs.log_level.upper(), @@ -871,12 +936,15 @@ if module_kwargs.log_level == "debug": torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True) + # 2. Prepare each part's arguments prepare_args(whisper_stt_handler_kwargs, "stt") prepare_args(language_model_handler_kwargs, "lm") prepare_args(parler_tts_handler_kwargs, "tts") + # 3. Build the pipeline stop_event = Event() - should_listen = Event() + # used to stop putting received audio chunks in queue until all sentences have been processed by the TTS + should_listen = Event() recv_audio_chunks_queue = Queue() send_audio_chunks_queue = Queue() spoken_prompt_queue = Queue() @@ -926,6 +994,7 @@ def main(): port=socket_sender_kwargs.send_port, ) + # 4. Run the pipeline try: pipeline_manager = ThreadManager([vad, tts, lm, stt, recv_handler, send_handler]) pipeline_manager.start()