diff --git a/LLM/language_model.py b/LLM/language_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5369c7350c52f9240fe064ffac6316db0196b5ec
--- /dev/null
+++ b/LLM/language_model.py
@@ -0,0 +1,134 @@
+from threading import Thread
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    pipeline,
+    TextIteratorStreamer,
+)
+import torch
+
+from LLM.chat import Chat
+from baseHandler import BaseHandler
+from rich.console import Console
+import logging
+from nltk import sent_tokenize
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class LanguageModelHandler(BaseHandler):
+    """
+    Handles the language model part.
+    """
+
+    def setup(
+        self,
+        model_name="microsoft/Phi-3-mini-4k-instruct",
+        device="cuda",
+        torch_dtype="float16",
+        gen_kwargs={},
+        user_role="user",
+        chat_size=1,
+        init_chat_role=None,
+        init_chat_prompt="You are a helpful AI assistant.",
+    ):
+        self.device = device
+        self.torch_dtype = getattr(torch, torch_dtype)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype=torch_dtype, trust_remote_code=True
+        ).to(device)
+        self.pipe = pipeline(
+            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
+        )
+        self.streamer = TextIteratorStreamer(
+            self.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True,
+        )
+        self.gen_kwargs = {
+            "streamer": self.streamer,
+            "return_full_text": False,
+            **gen_kwargs,
+        }
+
+        self.chat = Chat(chat_size)
+        if init_chat_role:
+            if not init_chat_prompt:
+                raise ValueError(
+                    "An initial prompt needs to be specified when setting init_chat_role."
+                )
+            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
+        self.user_role = user_role
+
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+
+        dummy_input_text = "Write me a poem about Machine Learning."
+        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
+        warmup_gen_kwargs = {
+            "min_new_tokens": self.gen_kwargs["max_new_tokens"],
+            "max_new_tokens": self.gen_kwargs["max_new_tokens"],
+            **self.gen_kwargs,
+        }
+
+        n_steps = 2
+
+        if self.device == "cuda":
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            torch.cuda.synchronize()
+            start_event.record()
+
+        for _ in range(n_steps):
+            thread = Thread(
+                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
+            )
+            thread.start()
+            for _ in self.streamer:
+                pass
+
+        if self.device == "cuda":
+            end_event.record()
+            torch.cuda.synchronize()
+
+            logger.info(
+                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
+            )
+
+    def process(self, prompt):
+        logger.debug("inferring language model...")
+
+        self.chat.append({"role": self.user_role, "content": prompt})
+        thread = Thread(
+            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
+        )
+        thread.start()
+        if self.device == "mps":
+            generated_text = ""
+            for new_text in self.streamer:
+                generated_text += new_text
+            printable_text = generated_text
+            torch.mps.empty_cache()
+        else:
+            generated_text, printable_text = "", ""
+            for new_text in self.streamer:
+                generated_text += new_text
+                printable_text += new_text
+                sentences = sent_tokenize(printable_text)
+                if len(sentences) > 1:
+                    yield (sentences[0])
+                    printable_text = new_text
+
+        self.chat.append({"role": "assistant", "content": generated_text})
+
+        # don't forget last sentence
+        yield printable_text
diff --git a/LLM/mlx_lm.py b/LLM/mlx_language_model.py
similarity index 95%
rename from LLM/mlx_lm.py
rename to LLM/mlx_language_model.py
index 0192afb36e0fcfe5a166d7589a9fe6f38e8cfd05..65c3b2da9f360c7524df71e4d1d2ee599391e958 100644
--- a/LLM/mlx_lm.py
+++ b/LLM/mlx_language_model.py
@@ -66,13 +66,15 @@ class MLXLanguageModelHandler(BaseHandler):
         logger.debug("infering language model...")
 
         self.chat.append({"role": self.user_role, "content": prompt})
-
+        # Remove system messages if using a Gemma model
         if "gemma" in self.model_name.lower():
-            chat_messages = [msg for msg in self.chat.to_list() if msg["role"] != "system"]
+            chat_messages = [
+                msg for msg in self.chat.to_list() if msg["role"] != "system"
+            ]
         else:
             chat_messages = self.chat.to_list()
-
+
         prompt = self.tokenizer.apply_chat_template(
             chat_messages, tokenize=False, add_generation_prompt=True
         )
 
diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f13a95eb57fa12f31fcb117e76512c7cb728742
--- /dev/null
+++ b/STT/whisper_stt_handler.py
@@ -0,0 +1,113 @@
+from time import perf_counter
+from transformers import (
+    AutoModelForSpeechSeq2Seq,
+    AutoProcessor,
+)
+import torch
+
+from baseHandler import BaseHandler
+from rich.console import Console
+import logging
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class WhisperSTTHandler(BaseHandler):
+    """
+    Handles the Speech To Text generation using a Whisper model.
+ """ + + def setup( + self, + model_name="distil-whisper/distil-large-v3", + device="cuda", + torch_dtype="float16", + compile_mode=None, + gen_kwargs={}, + ): + self.device = device + self.torch_dtype = getattr(torch, torch_dtype) + self.compile_mode = compile_mode + self.gen_kwargs = gen_kwargs + + self.processor = AutoProcessor.from_pretrained(model_name) + self.model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_name, + torch_dtype=self.torch_dtype, + ).to(device) + + # compile + if self.compile_mode: + self.model.generation_config.cache_implementation = "static" + self.model.forward = torch.compile( + self.model.forward, mode=self.compile_mode, fullgraph=True + ) + self.warmup() + + def prepare_model_inputs(self, spoken_prompt): + input_features = self.processor( + spoken_prompt, sampling_rate=16000, return_tensors="pt" + ).input_features + input_features = input_features.to(self.device, dtype=self.torch_dtype) + + return input_features + + def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + + # 2 warmup steps for no compile or compile mode with CUDA graphs capture + n_steps = 1 if self.compile_mode == "default" else 2 + dummy_input = torch.randn( + (1, self.model.config.num_mel_bins, 3000), + dtype=self.torch_dtype, + device=self.device, + ) + if self.compile_mode not in (None, "default"): + # generating more tokens than previously will trigger CUDA graphs capture + # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation + warmup_gen_kwargs = { + "min_new_tokens": self.gen_kwargs["max_new_tokens"], + "max_new_tokens": self.gen_kwargs["max_new_tokens"], + **self.gen_kwargs, + } + else: + warmup_gen_kwargs = self.gen_kwargs + + if self.device == "cuda": + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + + for _ in range(n_steps): + _ = self.model.generate(dummy_input, **warmup_gen_kwargs) + + if self.device == "cuda": + end_event.record() + torch.cuda.synchronize() + + logger.info( + f"{self.__class__.__name__}: warmed up! 
time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" + ) + + def process(self, spoken_prompt): + logger.debug("infering whisper...") + + global pipeline_start + pipeline_start = perf_counter() + + input_features = self.prepare_model_inputs(spoken_prompt) + pred_ids = self.model.generate(input_features, **self.gen_kwargs) + pred_text = self.processor.batch_decode( + pred_ids, skip_special_tokens=True, decode_with_timestamps=False + )[0] + + logger.debug("finished whisper inference") + console.print(f"[yellow]USER: {pred_text}") + + yield pred_text diff --git a/TTS/melotts.py b/TTS/melo_handler.py similarity index 98% rename from TTS/melotts.py rename to TTS/melo_handler.py index fad87d93e1e8bd77f0c812653de8ed8edd33025e..9897679a02fee6ffaadadb521481cf69915e732c 100644 --- a/TTS/melotts.py +++ b/TTS/melo_handler.py @@ -24,7 +24,6 @@ class MeloTTSHandler(BaseHandler): gen_kwargs={}, # Unused blocksize=512, ): - print(device) self.should_listen = should_listen self.device = device self.model = TTS(language=language, device=device) diff --git a/TTS/parler_handler.py b/TTS/parler_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..efeb5a84145b65a65bdb7826344a0370a680b111 --- /dev/null +++ b/TTS/parler_handler.py @@ -0,0 +1,181 @@ +from threading import Thread +from time import perf_counter +from baseHandler import BaseHandler +import numpy as np +import torch +from transformers import ( + AutoTokenizer, +) +from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer +import librosa +import logging +from rich.console import Console +from utils.utils import next_power_of_2 + +torch._inductor.config.fx_graph_cache = True +# mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS +torch._dynamo.config.cache_size_limit = 15 + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +console = Console() + + +class ParlerTTSHandler(BaseHandler): + def setup( + self, + should_listen, + model_name="ylacombe/parler-tts-mini-jenny-30H", + device="cuda", + torch_dtype="float16", + compile_mode=None, + gen_kwargs={}, + max_prompt_pad_length=8, + description=( + "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " + "She speaks very fast." + ), + play_steps_s=1, + blocksize=512, + ): + self.should_listen = should_listen + self.device = device + self.torch_dtype = getattr(torch, torch_dtype) + self.gen_kwargs = gen_kwargs + self.compile_mode = compile_mode + self.max_prompt_pad_length = max_prompt_pad_length + self.description = description + + self.description_tokenizer = AutoTokenizer.from_pretrained(model_name) + self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = ParlerTTSForConditionalGeneration.from_pretrained( + model_name, torch_dtype=self.torch_dtype + ).to(device) + + framerate = self.model.audio_encoder.config.frame_rate + self.play_steps = int(framerate * play_steps_s) + self.blocksize = blocksize + + if self.compile_mode not in (None, "default"): + logger.warning( + "Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. 
Reverting to 'default'" + ) + self.compile_mode = "default" + + if self.compile_mode: + self.model.generation_config.cache_implementation = "static" + self.model.forward = torch.compile( + self.model.forward, mode=self.compile_mode, fullgraph=True + ) + + self.warmup() + + def prepare_model_inputs( + self, + prompt, + max_length_prompt=50, + pad=False, + ): + pad_args_prompt = ( + {"padding": "max_length", "max_length": max_length_prompt} if pad else {} + ) + + tokenized_description = self.description_tokenizer( + self.description, return_tensors="pt" + ) + input_ids = tokenized_description.input_ids.to(self.device) + attention_mask = tokenized_description.attention_mask.to(self.device) + + tokenized_prompt = self.prompt_tokenizer( + prompt, return_tensors="pt", **pad_args_prompt + ) + prompt_input_ids = tokenized_prompt.input_ids.to(self.device) + prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device) + + gen_kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "prompt_input_ids": prompt_input_ids, + "prompt_attention_mask": prompt_attention_mask, + **self.gen_kwargs, + } + + return gen_kwargs + + def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + + if self.device == "cuda": + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # 2 warmup steps for no compile or compile mode with CUDA graphs capture + n_steps = 1 if self.compile_mode == "default" else 2 + + if self.device == "cuda": + torch.cuda.synchronize() + start_event.record() + if self.compile_mode: + pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)] + for pad_length in pad_lengths[::-1]: + model_kwargs = self.prepare_model_inputs( + "dummy prompt", max_length_prompt=pad_length, pad=True + ) + for _ in range(n_steps): + _ = self.model.generate(**model_kwargs) + logger.info(f"Warmed up length {pad_length} tokens!") + else: + model_kwargs = self.prepare_model_inputs("dummy prompt") + for _ in range(n_steps): + _ = self.model.generate(**model_kwargs) + + if self.device == "cuda": + end_event.record() + torch.cuda.synchronize() + logger.info( + f"{self.__class__.__name__}: warmed up! 
time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" + ) + + def process(self, llm_sentence): + console.print(f"[green]ASSISTANT: {llm_sentence}") + nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids) + + pad_args = {} + if self.compile_mode: + # pad to closest upper power of two + pad_length = next_power_of_2(nb_tokens) + logger.debug(f"padding to {pad_length}") + pad_args["pad"] = True + pad_args["max_length_prompt"] = pad_length + + tts_gen_kwargs = self.prepare_model_inputs( + llm_sentence, + **pad_args, + ) + + streamer = ParlerTTSStreamer( + self.model, device=self.device, play_steps=self.play_steps + ) + tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs} + torch.manual_seed(0) + thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs) + thread.start() + + for i, audio_chunk in enumerate(streamer): + global pipeline_start + if i == 0 and "pipeline_start" in globals(): + logger.info( + f"Time to first audio: {perf_counter() - pipeline_start:.3f}" + ) + audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) + audio_chunk = (audio_chunk * 32768).astype(np.int16) + for i in range(0, len(audio_chunk), self.blocksize): + yield np.pad( + audio_chunk[i : i + self.blocksize], + (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), + ) + + self.should_listen.set() diff --git a/VAD/vad_handler.py b/VAD/vad_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..e2da03f7b1cd612489d98f49ab56ea9ad45b942c --- /dev/null +++ b/VAD/vad_handler.py @@ -0,0 +1,64 @@ +from VAD.vad_iterator import VADIterator +from baseHandler import BaseHandler +import numpy as np +import torch +from rich.console import Console + +from utils.utils import int2float + +import logging + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +console = Console() + + +class VADHandler(BaseHandler): + """ + Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed + to the following part. 
+ """ + + def setup( + self, + should_listen, + thresh=0.3, + sample_rate=16000, + min_silence_ms=1000, + min_speech_ms=500, + max_speech_ms=float("inf"), + speech_pad_ms=30, + ): + self.should_listen = should_listen + self.sample_rate = sample_rate + self.min_silence_ms = min_silence_ms + self.min_speech_ms = min_speech_ms + self.max_speech_ms = max_speech_ms + self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad") + self.iterator = VADIterator( + self.model, + threshold=thresh, + sampling_rate=sample_rate, + min_silence_duration_ms=min_silence_ms, + speech_pad_ms=speech_pad_ms, + ) + + def process(self, audio_chunk): + audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16) + audio_float32 = int2float(audio_int16) + vad_output = self.iterator(torch.from_numpy(audio_float32)) + if vad_output is not None and len(vad_output) != 0: + logger.debug("VAD: end of speech detected") + array = torch.cat(vad_output).cpu().numpy() + duration_ms = len(array) / self.sample_rate * 1000 + if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms: + logger.debug( + f"audio input of duration: {len(array) / self.sample_rate}s, skipping" + ) + else: + self.should_listen.clear() + logger.debug("Stop listening") + yield array diff --git a/utils.py b/VAD/vad_iterator.py similarity index 89% rename from utils.py rename to VAD/vad_iterator.py index 3e2e9bc1d98b5e3b350a7fb865b0b7414af9d3aa..bd272f1dd7bfb432b5323e8906730e158ad5c2ff 100644 --- a/utils.py +++ b/VAD/vad_iterator.py @@ -1,24 +1,6 @@ -import numpy as np import torch -def next_power_of_2(x): - return 1 if x == 0 else 2 ** (x - 1).bit_length() - - -def int2float(sound): - """ - Taken from https://github.com/snakers4/silero-vad - """ - - abs_max = np.abs(sound).max() - sound = sound.astype("float32") - if abs_max > 0: - sound *= 1 / 32768 - sound = sound.squeeze() # depends on the use case - return sound - - class VADIterator: def __init__( self, diff --git a/local_audio_streamer.py b/connections/local_audio_streamer.py similarity index 100% rename from local_audio_streamer.py rename to connections/local_audio_streamer.py diff --git a/connections/socket_receiver.py b/connections/socket_receiver.py new file mode 100644 index 0000000000000000000000000000000000000000..19bda8da081da5a949c1bd046f84c4fa43b56403 --- /dev/null +++ b/connections/socket_receiver.py @@ -0,0 +1,63 @@ +import socket +from rich.console import Console +import logging + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +console = Console() + + +class SocketReceiver: + """ + Handles reception of the audio packets from the client. 
+ """ + + def __init__( + self, + stop_event, + queue_out, + should_listen, + host="0.0.0.0", + port=12345, + chunk_size=1024, + ): + self.stop_event = stop_event + self.queue_out = queue_out + self.should_listen = should_listen + self.chunk_size = chunk_size + self.host = host + self.port = port + + def receive_full_chunk(self, conn, chunk_size): + data = b"" + while len(data) < chunk_size: + packet = conn.recv(chunk_size - len(data)) + if not packet: + # connection closed + return None + data += packet + return data + + def run(self): + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind((self.host, self.port)) + self.socket.listen(1) + logger.info("Receiver waiting to be connected...") + self.conn, _ = self.socket.accept() + logger.info("receiver connected") + + self.should_listen.set() + while not self.stop_event.is_set(): + audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size) + if audio_chunk is None: + # connection closed + self.queue_out.put(b"END") + break + if self.should_listen.is_set(): + self.queue_out.put(audio_chunk) + self.conn.close() + logger.info("Receiver closed") diff --git a/connections/socket_sender.py b/connections/socket_sender.py new file mode 100644 index 0000000000000000000000000000000000000000..587343b8204e4e79af29346668a08ab4f6d6486d --- /dev/null +++ b/connections/socket_sender.py @@ -0,0 +1,39 @@ +import socket +from rich.console import Console +import logging + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +console = Console() + + +class SocketSender: + """ + Handles sending generated audio packets to the clients. + """ + + def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346): + self.stop_event = stop_event + self.queue_in = queue_in + self.host = host + self.port = port + + def run(self): + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind((self.host, self.port)) + self.socket.listen(1) + logger.info("Sender waiting to be connected...") + self.conn, _ = self.socket.accept() + logger.info("sender connected") + + while not self.stop_event.is_set(): + audio_chunk = self.queue_in.get() + self.conn.sendall(audio_chunk) + if isinstance(audio_chunk, bytes) and audio_chunk == b"END": + break + self.conn.close() + logger.info("Sender closed") diff --git a/s2s_pipeline.py b/s2s_pipeline.py index 9ed41eb75798106f5be626f7c1c12d4a6169a2b1..7a9e204ae16bdcde3992ecfa09bda7e73dabae95 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -1,43 +1,32 @@ import logging import os -import socket import sys -import threading from copy import copy from pathlib import Path from queue import Queue -from threading import Event, Thread -from time import perf_counter +from threading import Event from typing import Optional from sys import platform +from VAD.vad_handler import VADHandler from arguments_classes.language_model_arguments import LanguageModelHandlerArguments -from arguments_classes.mlx_language_model_arguments import MLXLanguageModelHandlerArguments +from arguments_classes.mlx_language_model_arguments import ( + MLXLanguageModelHandlerArguments, +) from arguments_classes.module_arguments import ModuleArguments from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments from arguments_classes.socket_receiver_arguments import SocketReceiverArguments 
from arguments_classes.socket_sender_arguments import SocketSenderArguments from arguments_classes.vad_arguments import VADHandlerArguments from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments -from baseHandler import BaseHandler from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments -import numpy as np import torch import nltk -from nltk.tokenize import sent_tokenize from rich.console import Console from transformers import ( - AutoModelForCausalLM, - AutoModelForSpeechSeq2Seq, - AutoProcessor, - AutoTokenizer, HfArgumentParser, - pipeline, - TextIteratorStreamer, ) -from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer -import librosa -from utils import VADIterator, int2float, next_power_of_2 +from utils.thread_manager import ThreadManager # Ensure that the necessary NLTK resources are available try: @@ -58,550 +47,6 @@ os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp") console = Console() -class ThreadManager: - """ - Manages multiple threads used to execute given handler tasks. - """ - - def __init__(self, handlers): - self.handlers = handlers - self.threads = [] - - def start(self): - for handler in self.handlers: - thread = threading.Thread(target=handler.run) - self.threads.append(thread) - thread.start() - - def stop(self): - for handler in self.handlers: - handler.stop_event.set() - for thread in self.threads: - thread.join() - - -class SocketReceiver: - """ - Handles reception of the audio packets from the client. - """ - - def __init__( - self, - stop_event, - queue_out, - should_listen, - host="0.0.0.0", - port=12345, - chunk_size=1024, - ): - self.stop_event = stop_event - self.queue_out = queue_out - self.should_listen = should_listen - self.chunk_size = chunk_size - self.host = host - self.port = port - - def receive_full_chunk(self, conn, chunk_size): - data = b"" - while len(data) < chunk_size: - packet = conn.recv(chunk_size - len(data)) - if not packet: - # connection closed - return None - data += packet - return data - - def run(self): - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind((self.host, self.port)) - self.socket.listen(1) - logger.info("Receiver waiting to be connected...") - self.conn, _ = self.socket.accept() - logger.info("receiver connected") - - self.should_listen.set() - while not self.stop_event.is_set(): - audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size) - if audio_chunk is None: - # connection closed - self.queue_out.put(b"END") - break - if self.should_listen.is_set(): - self.queue_out.put(audio_chunk) - self.conn.close() - logger.info("Receiver closed") - - -class SocketSender: - """ - Handles sending generated audio packets to the clients. 
- """ - - def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346): - self.stop_event = stop_event - self.queue_in = queue_in - self.host = host - self.port = port - - def run(self): - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind((self.host, self.port)) - self.socket.listen(1) - logger.info("Sender waiting to be connected...") - self.conn, _ = self.socket.accept() - logger.info("sender connected") - - while not self.stop_event.is_set(): - audio_chunk = self.queue_in.get() - self.conn.sendall(audio_chunk) - if isinstance(audio_chunk, bytes) and audio_chunk == b"END": - break - self.conn.close() - logger.info("Sender closed") - - -class VADHandler(BaseHandler): - """ - Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed - to the following part. - """ - - def setup( - self, - should_listen, - thresh=0.3, - sample_rate=16000, - min_silence_ms=1000, - min_speech_ms=500, - max_speech_ms=float("inf"), - speech_pad_ms=30, - ): - self.should_listen = should_listen - self.sample_rate = sample_rate - self.min_silence_ms = min_silence_ms - self.min_speech_ms = min_speech_ms - self.max_speech_ms = max_speech_ms - self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad") - self.iterator = VADIterator( - self.model, - threshold=thresh, - sampling_rate=sample_rate, - min_silence_duration_ms=min_silence_ms, - speech_pad_ms=speech_pad_ms, - ) - - def process(self, audio_chunk): - audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16) - audio_float32 = int2float(audio_int16) - vad_output = self.iterator(torch.from_numpy(audio_float32)) - if vad_output is not None and len(vad_output) != 0: - logger.debug("VAD: end of speech detected") - array = torch.cat(vad_output).cpu().numpy() - duration_ms = len(array) / self.sample_rate * 1000 - if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms: - logger.debug( - f"audio input of duration: {len(array) / self.sample_rate}s, skipping" - ) - else: - self.should_listen.clear() - logger.debug("Stop listening") - yield array - - -class WhisperSTTHandler(BaseHandler): - """ - Handles the Speech To Text generation using a Whisper model. 
- """ - - def setup( - self, - model_name="distil-whisper/distil-large-v3", - device="cuda", - torch_dtype="float16", - compile_mode=None, - gen_kwargs={}, - ): - self.device = device - self.torch_dtype = getattr(torch, torch_dtype) - self.compile_mode = compile_mode - self.gen_kwargs = gen_kwargs - - self.processor = AutoProcessor.from_pretrained(model_name) - self.model = AutoModelForSpeechSeq2Seq.from_pretrained( - model_name, - torch_dtype=self.torch_dtype, - ).to(device) - - # compile - if self.compile_mode: - self.model.generation_config.cache_implementation = "static" - self.model.forward = torch.compile( - self.model.forward, mode=self.compile_mode, fullgraph=True - ) - self.warmup() - - def prepare_model_inputs(self, spoken_prompt): - input_features = self.processor( - spoken_prompt, sampling_rate=16000, return_tensors="pt" - ).input_features - input_features = input_features.to(self.device, dtype=self.torch_dtype) - - return input_features - - def warmup(self): - logger.info(f"Warming up {self.__class__.__name__}") - - # 2 warmup steps for no compile or compile mode with CUDA graphs capture - n_steps = 1 if self.compile_mode == "default" else 2 - dummy_input = torch.randn( - (1, self.model.config.num_mel_bins, 3000), - dtype=self.torch_dtype, - device=self.device, - ) - if self.compile_mode not in (None, "default"): - # generating more tokens than previously will trigger CUDA graphs capture - # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation - warmup_gen_kwargs = { - "min_new_tokens": self.gen_kwargs["max_new_tokens"], - "max_new_tokens": self.gen_kwargs["max_new_tokens"], - **self.gen_kwargs, - } - else: - warmup_gen_kwargs = self.gen_kwargs - - if self.device == "cuda": - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start_event.record() - - for _ in range(n_steps): - _ = self.model.generate(dummy_input, **warmup_gen_kwargs) - - if self.device == "cuda": - end_event.record() - torch.cuda.synchronize() - - logger.info( - f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" - ) - - def process(self, spoken_prompt): - logger.debug("infering whisper...") - - global pipeline_start - pipeline_start = perf_counter() - - input_features = self.prepare_model_inputs(spoken_prompt) - pred_ids = self.model.generate(input_features, **self.gen_kwargs) - pred_text = self.processor.batch_decode( - pred_ids, skip_special_tokens=True, decode_with_timestamps=False - )[0] - - logger.debug("finished whisper inference") - console.print(f"[yellow]USER: {pred_text}") - - yield pred_text - - -class Chat: - """ - Handles the chat using to avoid OOM issues. - """ - - def __init__(self, size): - self.size = size - self.init_chat_message = None - # maxlen is necessary pair, since a each new step we add an prompt and assitant answer - self.buffer = [] - - def append(self, item): - self.buffer.append(item) - if len(self.buffer) == 2 * (self.size + 1): - self.buffer.pop(0) - self.buffer.pop(0) - - def init_chat(self, init_chat_message): - self.init_chat_message = init_chat_message - - def to_list(self): - if self.init_chat_message: - return [self.init_chat_message] + self.buffer - else: - return self.buffer - - -class LanguageModelHandler(BaseHandler): - """ - Handles the language model part. 
- """ - - def setup( - self, - model_name="microsoft/Phi-3-mini-4k-instruct", - device="cuda", - torch_dtype="float16", - gen_kwargs={}, - user_role="user", - chat_size=1, - init_chat_role=None, - init_chat_prompt="You are a helpful AI assistant.", - ): - self.device = device - self.torch_dtype = getattr(torch, torch_dtype) - - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=torch_dtype, trust_remote_code=True - ).to(device) - self.pipe = pipeline( - "text-generation", model=self.model, tokenizer=self.tokenizer, device=device - ) - self.streamer = TextIteratorStreamer( - self.tokenizer, - skip_prompt=True, - skip_special_tokens=True, - ) - self.gen_kwargs = { - "streamer": self.streamer, - "return_full_text": False, - **gen_kwargs, - } - - self.chat = Chat(chat_size) - if init_chat_role: - if not init_chat_prompt: - raise ValueError( - "An initial promt needs to be specified when setting init_chat_role." - ) - self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt}) - self.user_role = user_role - - self.warmup() - - def warmup(self): - logger.info(f"Warming up {self.__class__.__name__}") - - dummy_input_text = "Write me a poem about Machine Learning." - dummy_chat = [{"role": self.user_role, "content": dummy_input_text}] - warmup_gen_kwargs = { - "min_new_tokens": self.gen_kwargs["max_new_tokens"], - "max_new_tokens": self.gen_kwargs["max_new_tokens"], - **self.gen_kwargs, - } - - n_steps = 2 - - if self.device == "cuda": - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start_event.record() - - for _ in range(n_steps): - thread = Thread( - target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs - ) - thread.start() - for _ in self.streamer: - pass - - if self.device == "cuda": - end_event.record() - torch.cuda.synchronize() - - logger.info( - f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" - ) - - def process(self, prompt): - logger.debug("infering language model...") - - self.chat.append({"role": self.user_role, "content": prompt}) - thread = Thread( - target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs - ) - thread.start() - if self.device == "mps": - generated_text = "" - for new_text in self.streamer: - generated_text += new_text - printable_text = generated_text - torch.mps.empty_cache() - else: - generated_text, printable_text = "", "" - for new_text in self.streamer: - generated_text += new_text - printable_text += new_text - sentences = sent_tokenize(printable_text) - if len(sentences) > 1: - yield (sentences[0]) - printable_text = new_text - - self.chat.append({"role": "assistant", "content": generated_text}) - - # don't forget last sentence - yield printable_text - - -class ParlerTTSHandler(BaseHandler): - def setup( - self, - should_listen, - model_name="ylacombe/parler-tts-mini-jenny-30H", - device="cuda", - torch_dtype="float16", - compile_mode=None, - gen_kwargs={}, - max_prompt_pad_length=8, - description=( - "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " - "She speaks very fast." 
- ), - play_steps_s=1, - blocksize=512, - ): - self.should_listen = should_listen - self.device = device - self.torch_dtype = getattr(torch, torch_dtype) - self.gen_kwargs = gen_kwargs - self.compile_mode = compile_mode - self.max_prompt_pad_length = max_prompt_pad_length - self.description = description - - self.description_tokenizer = AutoTokenizer.from_pretrained(model_name) - self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = ParlerTTSForConditionalGeneration.from_pretrained( - model_name, torch_dtype=self.torch_dtype - ).to(device) - - framerate = self.model.audio_encoder.config.frame_rate - self.play_steps = int(framerate * play_steps_s) - self.blocksize = blocksize - - if self.compile_mode not in (None, "default"): - logger.warning( - "Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'" - ) - self.compile_mode = "default" - - if self.compile_mode: - self.model.generation_config.cache_implementation = "static" - self.model.forward = torch.compile( - self.model.forward, mode=self.compile_mode, fullgraph=True - ) - - self.warmup() - - def prepare_model_inputs( - self, - prompt, - max_length_prompt=50, - pad=False, - ): - pad_args_prompt = ( - {"padding": "max_length", "max_length": max_length_prompt} if pad else {} - ) - - tokenized_description = self.description_tokenizer( - self.description, return_tensors="pt" - ) - input_ids = tokenized_description.input_ids.to(self.device) - attention_mask = tokenized_description.attention_mask.to(self.device) - - tokenized_prompt = self.prompt_tokenizer( - prompt, return_tensors="pt", **pad_args_prompt - ) - prompt_input_ids = tokenized_prompt.input_ids.to(self.device) - prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device) - - gen_kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "prompt_input_ids": prompt_input_ids, - "prompt_attention_mask": prompt_attention_mask, - **self.gen_kwargs, - } - - return gen_kwargs - - def warmup(self): - logger.info(f"Warming up {self.__class__.__name__}") - - if self.device == "cuda": - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # 2 warmup steps for no compile or compile mode with CUDA graphs capture - n_steps = 1 if self.compile_mode == "default" else 2 - - if self.device == "cuda": - torch.cuda.synchronize() - start_event.record() - if self.compile_mode: - pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)] - for pad_length in pad_lengths[::-1]: - model_kwargs = self.prepare_model_inputs( - "dummy prompt", max_length_prompt=pad_length, pad=True - ) - for _ in range(n_steps): - _ = self.model.generate(**model_kwargs) - logger.info(f"Warmed up length {pad_length} tokens!") - else: - model_kwargs = self.prepare_model_inputs("dummy prompt") - for _ in range(n_steps): - _ = self.model.generate(**model_kwargs) - - if self.device == "cuda": - end_event.record() - torch.cuda.synchronize() - logger.info( - f"{self.__class__.__name__}: warmed up! 
time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" - ) - - def process(self, llm_sentence): - console.print(f"[green]ASSISTANT: {llm_sentence}") - nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids) - - pad_args = {} - if self.compile_mode: - # pad to closest upper power of two - pad_length = next_power_of_2(nb_tokens) - logger.debug(f"padding to {pad_length}") - pad_args["pad"] = True - pad_args["max_length_prompt"] = pad_length - - tts_gen_kwargs = self.prepare_model_inputs( - llm_sentence, - **pad_args, - ) - - streamer = ParlerTTSStreamer( - self.model, device=self.device, play_steps=self.play_steps - ) - tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs} - torch.manual_seed(0) - thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs) - thread.start() - - for i, audio_chunk in enumerate(streamer): - if i == 0 and "pipeline_start" in globals(): - logger.info( - f"Time to first audio: {perf_counter() - pipeline_start:.3f}" - ) - audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) - audio_chunk = (audio_chunk * 32768).astype(np.int16) - for i in range(0, len(audio_chunk), self.blocksize): - yield np.pad( - audio_chunk[i : i + self.blocksize], - (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), - ) - - self.should_listen.set() - - def prepare_args(args, prefix): """ Rename arguments by removing the prefix and prepares the gen_kwargs. @@ -745,7 +190,7 @@ def main(): lm_response_queue = Queue() if module_kwargs.mode == "local": - from local_audio_streamer import LocalAudioStreamer + from connections.local_audio_streamer import LocalAudioStreamer local_audio_streamer = LocalAudioStreamer( input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue @@ -753,6 +198,9 @@ def main(): comms_handlers = [local_audio_streamer] should_listen.set() else: + from connections.socket_receiver import SocketReceiver + from connections.socket_sender import SocketSender + comms_handlers = [ SocketReceiver( stop_event, @@ -778,6 +226,8 @@ def main(): setup_kwargs=vars(vad_handler_kwargs), ) if module_kwargs.stt == "whisper": + from STT.whisper_stt_handler import WhisperSTTHandler + stt = WhisperSTTHandler( stop_event, queue_in=spoken_prompt_queue, @@ -786,6 +236,7 @@ def main(): ) elif module_kwargs.stt == "whisper-mlx": from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler + stt = LightningWhisperSTTHandler( stop_event, queue_in=spoken_prompt_queue, @@ -795,6 +246,8 @@ def main(): else: raise ValueError("The STT should be either whisper or whisper-mlx") if module_kwargs.llm == "transformers": + from LLM.language_model import LanguageModelHandler + lm = LanguageModelHandler( stop_event, queue_in=text_prompt_queue, @@ -802,7 +255,8 @@ def main(): setup_kwargs=vars(language_model_handler_kwargs), ) elif module_kwargs.llm == "mlx-lm": - from LLM.mlx_lm import MLXLanguageModelHandler + from LLM.mlx_language_model import MLXLanguageModelHandler + lm = MLXLanguageModelHandler( stop_event, queue_in=text_prompt_queue, @@ -812,9 +266,8 @@ def main(): else: raise ValueError("The LLM should be either transformers or mlx-lm") if module_kwargs.tts == "parler": - torch._inductor.config.fx_graph_cache = True - # mind about this parameter ! 
should be >= 2 * number of padded prompt sizes for TTS - torch._dynamo.config.cache_size_limit = 15 + from TTS.parler_handler import ParlerTTSHandler + tts = ParlerTTSHandler( stop_event, queue_in=lm_response_queue, @@ -825,7 +278,7 @@ def main(): elif module_kwargs.tts == "melo": try: - from TTS.melotts import MeloTTSHandler + from TTS.melo_handler import MeloTTSHandler except RuntimeError as e: logger.error( "Error importing MeloTTSHandler. You might need to run: python -m unidic download" diff --git a/utils/thread_manager.py b/utils/thread_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1ca4a074833cb3953ddfa63e71822e26b23143 --- /dev/null +++ b/utils/thread_manager.py @@ -0,0 +1,23 @@ +import threading + + +class ThreadManager: + """ + Manages multiple threads used to execute given handler tasks. + """ + + def __init__(self, handlers): + self.handlers = handlers + self.threads = [] + + def start(self): + for handler in self.handlers: + thread = threading.Thread(target=handler.run) + self.threads.append(thread) + thread.start() + + def stop(self): + for handler in self.handlers: + handler.stop_event.set() + for thread in self.threads: + thread.join() diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac399486a1502fddbec823d2dccbb23f0558be14 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,18 @@ +import numpy as np + + +def next_power_of_2(x): + return 1 if x == 0 else 2 ** (x - 1).bit_length() + + +def int2float(sound): + """ + Taken from https://github.com/snakers4/silero-vad + """ + + abs_max = np.abs(sound).max() + sound = sound.astype("float32") + if abs_max > 0: + sound *= 1 / 32768 + sound = sound.squeeze() # depends on the use case + return sound
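
Illustrative usage sketch (not part of the patch): the example below shows one way the relocated modules could be wired together end to end, mirroring the handler/queue pattern of s2s_pipeline.main(). Class names, module paths, and constructor parameters are taken from the diff above; the queue variable names, the gen_kwargs values, and passing should_listen through setup_kwargs are assumptions made for illustration only.

# Minimal wiring sketch, assuming a CUDA device and default model choices.
from queue import Queue
from threading import Event

from connections.socket_receiver import SocketReceiver
from connections.socket_sender import SocketSender
from VAD.vad_handler import VADHandler
from STT.whisper_stt_handler import WhisperSTTHandler
from LLM.language_model import LanguageModelHandler
from TTS.parler_handler import ParlerTTSHandler
from utils.thread_manager import ThreadManager

stop_event, should_listen = Event(), Event()
# One queue per hop: raw audio -> speech segments -> text -> LLM sentences -> audio out.
recv_audio_q, spoken_prompt_q, text_prompt_q, lm_response_q, send_audio_q = (
    Queue() for _ in range(5)
)

handlers = [
    SocketReceiver(stop_event, recv_audio_q, should_listen),
    VADHandler(
        stop_event,
        queue_in=recv_audio_q,
        queue_out=spoken_prompt_q,
        # Assumption: should_listen forwarded to VADHandler.setup() via setup_kwargs.
        setup_kwargs={"should_listen": should_listen},
    ),
    WhisperSTTHandler(
        stop_event,
        queue_in=spoken_prompt_q,
        queue_out=text_prompt_q,
        setup_kwargs={"gen_kwargs": {"max_new_tokens": 128}},
    ),
    LanguageModelHandler(
        stop_event,
        queue_in=text_prompt_q,
        queue_out=lm_response_q,
        # max_new_tokens is required: warmup() reads gen_kwargs["max_new_tokens"].
        setup_kwargs={"gen_kwargs": {"max_new_tokens": 128}},
    ),
    ParlerTTSHandler(
        stop_event,
        queue_in=lm_response_q,
        queue_out=send_audio_q,
        setup_kwargs={"should_listen": should_listen},
    ),
    SocketSender(stop_event, send_audio_q),
]

manager = ThreadManager(handlers)
manager.start()  # spawns one thread per handler.run()
# ... later: manager.stop() sets each handler's stop_event and joins the threads.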