diff --git a/LLM/chat.py b/LLM/chat.py
index 4245830f02b473223cf143f37f636abe8c310ae8..bc8ac4fbc266d084f2ffc636564c457162a130f1 100644
--- a/LLM/chat.py
+++ b/LLM/chat.py
@@ -1,6 +1,3 @@
-
-
-
 class Chat:
     """
     Handles the chat using to avoid OOM issues.
diff --git a/LLM/mlx_lm.py b/LLM/mlx_lm.py
index ff63b711a7bd3105bf6af1556d9adf21fa98370f..a772e3a34a727c3656f87bff532f012bc915c43a 100644
--- a/LLM/mlx_lm.py
+++ b/LLM/mlx_lm.py
@@ -4,6 +4,7 @@ from baseHandler import BaseHandler
 from mlx_lm import load, stream_generate, generate
 from rich.console import Console
 import torch
+
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 )
@@ -11,6 +12,7 @@ logger = logging.getLogger(__name__)
 
 console = Console()
 
+
 class MLXLanguageModelHandler(BaseHandler):
     """
     Handles the language model part.
@@ -28,7 +30,7 @@ class MLXLanguageModelHandler(BaseHandler):
         init_chat_prompt="You are a helpful AI assistant.",
     ):
         self.model_name = model_name
-        model_id = 'microsoft/Phi-3-mini-4k-instruct'
+        model_id = "microsoft/Phi-3-mini-4k-instruct"
         self.model, self.tokenizer = load(model_id)
         self.gen_kwargs = gen_kwargs
 
@@ -48,28 +50,40 @@ class MLXLanguageModelHandler(BaseHandler):
 
         dummy_input_text = "Write me a poem about Machine Learning."
         dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
-        
+
         n_steps = 2
         for _ in range(n_steps):
             prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
-            generate(self.model, self.tokenizer, prompt=prompt, max_tokens=self.gen_kwargs["max_new_tokens"], verbose=False)
-
+            generate(
+                self.model,
+                self.tokenizer,
+                prompt=prompt,
+                max_tokens=self.gen_kwargs["max_new_tokens"],
+                verbose=False,
+            )
 
     def process(self, prompt):
         logger.debug("infering language model...")
 
         self.chat.append({"role": self.user_role, "content": prompt})
-        prompt = self.tokenizer.apply_chat_template(self.chat.to_list(), tokenize=False, add_generation_prompt=True)
+        prompt = self.tokenizer.apply_chat_template(
+            self.chat.to_list(), tokenize=False, add_generation_prompt=True
+        )
 
         output = ""
         curr_output = ""
-        for t in stream_generate(self.model, self.tokenizer, prompt, max_tokens=self.gen_kwargs["max_new_tokens"]):
+        for t in stream_generate(
+            self.model,
+            self.tokenizer,
+            prompt,
+            max_tokens=self.gen_kwargs["max_new_tokens"],
+        ):
             output += t
             curr_output += t
-            if curr_output.endswith(('.', '?', '!', '<|end|>')):
-                yield curr_output.replace('<|end|>', '')
+            if curr_output.endswith((".", "?", "!", "<|end|>")):
+                yield curr_output.replace("<|end|>", "")
                 curr_output = ""
-        generated_text = output.replace('<|end|>', '')
+        generated_text = output.replace("<|end|>", "")
         torch.mps.empty_cache()
 
         self.chat.append({"role": "assistant", "content": generated_text})
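The reformatted `process` above streams tokens from `stream_generate` and flushes a chunk downstream whenever the accumulated text ends in sentence punctuation or Phi-3's `<|end|>` token, so the TTS stage can start speaking before generation finishes. A minimal, self-contained sketch of that chunking pattern (the token list below is a stand-in for the real stream):

```python
SENTENCE_ENDS = (".", "?", "!", "<|end|>")


def chunk_sentences(token_stream):
    """Group streamed tokens into sentence-sized chunks, stripping the end token."""
    buffer = ""
    for token in token_stream:
        buffer += token
        if buffer.endswith(SENTENCE_ENDS):  # str.endswith accepts a tuple
            yield buffer.replace("<|end|>", "")
            buffer = ""
    if buffer:  # flush any trailing partial sentence
        yield buffer


# Two sentences arriving token by token:
print(list(chunk_sentences(["Hi", " there", ".", " Bye", "!"])))
# ['Hi there.', ' Bye!']
```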
diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py
index 53709025c7438c559e06a098b70c697922eef6a1..be770ac6ce26bda98bc7d83b2b470c24d6d4140f 100644
--- a/STT/lightning_whisper_mlx_handler.py
+++ b/STT/lightning_whisper_mlx_handler.py
@@ -5,6 +5,7 @@ from lightning_whisper_mlx import LightningWhisperMLX
 import numpy as np
 from rich.console import Console
 import torch
+
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 )
@@ -26,12 +27,10 @@ class LightningWhisperSTTHandler(BaseHandler):
         compile_mode=None,
         gen_kwargs={},
     ):
-        if len(model_name.split('/')) > 1:
-            model_name = model_name.split('/')[-1]
+        if len(model_name.split("/")) > 1:
+            model_name = model_name.split("/")[-1]
         self.device = device
-        self.model = LightningWhisperMLX(
-            model=model_name, batch_size=6, quant=None
-        )
+        self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
         self.warmup()
 
     def warmup(self):
diff --git a/TTS/melotts.py b/TTS/melotts.py
index f1a712bb98fc050c03e114dd24e2d62ae63284b8..fad87d93e1e8bd77f0c812653de8ed8edd33025e 100644
--- a/TTS/melotts.py
+++ b/TTS/melotts.py
@@ -40,10 +40,13 @@ class MeloTTSHandler(BaseHandler):
         console.print(f"[green]ASSISTANT: {llm_sentence}")
         if self.device == "mps":
             import time
+
             start = time.time()
             torch.mps.synchronize()  # Waits for all kernels in all streams on the MPS device to complete.
             torch.mps.empty_cache()  # Frees all memory allocated by the MPS device.
-            time_it_took = time.time()-start # Removing this line makes it fail more often. I'm looking into it.
+            _ = (
+                time.time() - start
+            )  # Removing this line makes it fail more often. I'm looking into it.
 
         audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
         if len(audio_chunk) == 0:
diff --git a/arguments_classes/language_model_arguments.py b/arguments_classes/language_model_arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd66ca3795079a2bd033b22d5355b09c6883530e
--- /dev/null
+++ b/arguments_classes/language_model_arguments.py
@@ -0,0 +1,65 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class LanguageModelHandlerArguments:
+    lm_model_name: str = field(
+        default="HuggingFaceTB/SmolLM-360M-Instruct",
+        metadata={
+            "help": "The pretrained language model to use. Default is 'HuggingFaceTB/SmolLM-360M-Instruct'."
+        },
+    )
+    lm_device: str = field(
+        default="cuda",
+        metadata={
+            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
+        },
+    )
+    lm_torch_dtype: str = field(
+        default="float16",
+        metadata={
+            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
+        },
+    )
+    user_role: str = field(
+        default="user",
+        metadata={
+            "help": "Role assigned to the user in the chat context. Default is 'user'."
+        },
+    )
+    init_chat_role: str = field(
+        default="system",
+        metadata={
+            "help": "Initial role for setting up the chat context. Default is 'system'."
+        },
+    )
+    init_chat_prompt: str = field(
+        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
+        metadata={
+            "help": "The initial chat prompt to establish context for the language model. Defaults to a polite assistant prompt that asks for concise responses."
+        },
+    )
+    lm_gen_max_new_tokens: int = field(
+        default=128,
+        metadata={
+            "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
+        },
+    )
+    lm_gen_temperature: float = field(
+        default=0.0,
+        metadata={
+            "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
+        },
+    )
+    lm_gen_do_sample: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
+        },
+    )
+    chat_size: int = field(
+        default=2,
+        metadata={
+            "help": "Number of assistant-user interactions to keep in the chat history. None for no limit."
+        },
+    )
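Moving these dataclasses into `arguments_classes/` keeps `s2s_pipeline.py` focused on wiring. Assuming the pipeline parses them with `transformers`' `HfArgumentParser` (the parsing code itself is outside this diff), each field surfaces directly as a CLI flag:

```python
# Sketch: parsing a handler dataclass from the command line. The use of
# HfArgumentParser here is an assumption; only the dataclass is in this diff.
from transformers import HfArgumentParser

from arguments_classes.language_model_arguments import LanguageModelHandlerArguments

parser = HfArgumentParser((LanguageModelHandlerArguments,))
(lm_args,) = parser.parse_args_into_dataclasses()

# Each field becomes a flag, e.g.:
#   python s2s_pipeline.py --lm_model_name HuggingFaceTB/SmolLM-360M-Instruct \
#       --lm_gen_max_new_tokens 64 --chat_size 4
print(lm_args.lm_model_name, lm_args.lm_gen_max_new_tokens)
```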
diff --git a/handlers/melo_tts_handler.py b/arguments_classes/melo_tts_arguments.py
similarity index 96%
rename from handlers/melo_tts_handler.py
rename to arguments_classes/melo_tts_arguments.py
index 88616c34e1077b4d922fa820ad530dae6a42ac17..49fd3578ac0e376ab264baa7900e5fc0cc879c76 100644
--- a/handlers/melo_tts_handler.py
+++ b/arguments_classes/melo_tts_arguments.py
@@ -1,6 +1,4 @@
-
 from dataclasses import dataclass, field
-from typing import List
 
 
 @dataclass
@@ -23,4 +21,3 @@ class MeloTTSHandlerArguments:
             "help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']."
         },
     )
-
diff --git a/arguments_classes/module_arguments.py b/arguments_classes/module_arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..140559641dcb9d908e1e426c89eb0815af318241
--- /dev/null
+++ b/arguments_classes/module_arguments.py
@@ -0,0 +1,46 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class ModuleArguments:
+    device: Optional[str] = field(
+        default=None,
+        metadata={"help": "If specified, overrides the device for all handlers."},
+    )
+    mode: Optional[str] = field(
+        default="socket",
+        metadata={
+            "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'."
+        },
+    )
+    local_mac_optimal_settings: bool = field(
+        default=False,
+        metadata={
+            "help": "If specified, sets the optimal settings for macOS: whisper-mlx, MLX LM, and MeloTTS will be used."
+        },
+    )
+    stt: Optional[str] = field(
+        default="whisper",
+        metadata={
+            "help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'."
+        },
+    )
+    llm: Optional[str] = field(
+        default="transformers",
+        metadata={
+            "help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'."
+        },
+    )
+    tts: Optional[str] = field(
+        default="parler",
+        metadata={
+            "help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'."
+        },
+    )
+    log_level: str = field(
+        default="info",
+        metadata={
+            "help": "Provide logging level. Example: --log_level debug. Default is 'info'."
+        },
+    )
diff --git a/arguments_classes/parler_tts_arguments.py b/arguments_classes/parler_tts_arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..515943209d2068bc0e2a451eb08d1cce68cfb67d
--- /dev/null
+++ b/arguments_classes/parler_tts_arguments.py
@@ -0,0 +1,62 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class ParlerTTSHandlerArguments:
+    tts_model_name: str = field(
+        default="ylacombe/parler-tts-mini-jenny-30H",
+        metadata={
+            "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
+        },
+    )
+    tts_device: str = field(
+        default="cuda",
+        metadata={
+            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
+        },
+    )
+    tts_torch_dtype: str = field(
+        default="float16",
+        metadata={
+            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
+        },
+    )
+    tts_compile_mode: str = field(
+        default=None,
+        metadata={
+            "help": "Compile mode for torch compile. One of 'default', 'reduce-overhead' or 'max-autotune'. Default is None (no compilation)."
+        },
+    )
+    tts_gen_min_new_tokens: int = field(
+        default=64,
+        metadata={
+            "help": "Minimum number of new tokens to generate in a single completion. Default is 64."
+        },
+    )
+    tts_gen_max_new_tokens: int = field(
+        default=512,
+        metadata={
+            "help": "Maximum number of new tokens to generate in a single completion. Default is 512."
+        },
+    )
+    description: str = field(
+        default=(
+            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
+            "She speaks very fast."
+        ),
+        metadata={
+            "help": "Description of the speaker's voice and speaking style to guide the TTS model."
+        },
+    )
+    play_steps_s: float = field(
+        default=1.0,
+        metadata={
+            "help": "The time interval in seconds for playing back the generated speech in steps. Default is 1.0 second."
+        },
+    )
+    max_prompt_pad_length: int = field(
+        default=8,
+        metadata={
+            "help": "When using compilation, the prompt has to be padded to the closest power of 2. This parameter sets the maximum power of 2 allowed."
+        },
+    )
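`max_prompt_pad_length` only matters under `tts_compile_mode`: prompts are padded up to a power of two so `torch.compile` sees a bounded set of input shapes and can reuse compiled graphs. A sketch of that bucketing; the body of `next_power_of_2` (imported from `utils.py` in `s2s_pipeline.py` below) is an assumption here:

```python
import math


def next_power_of_2(x: int) -> int:
    # Assumed to match utils.next_power_of_2: smallest power of 2 >= x.
    return 1 if x <= 1 else 2 ** math.ceil(math.log2(x))


max_prompt_pad_length = 8  # pad buckets top out at 2**8 tokens


def padded_length(n_tokens: int) -> int:
    """Bucket a prompt length so compiled graphs are reused across prompts."""
    return min(next_power_of_2(n_tokens), 2**max_prompt_pad_length)


print([padded_length(n) for n in (3, 17, 100, 500)])  # [4, 32, 128, 256]
```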
diff --git a/arguments_classes/socket_receiver_arguments.py b/arguments_classes/socket_receiver_arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..2884edd89abdcbc2275ff7aa1f8dc08879a963af
--- /dev/null
+++ b/arguments_classes/socket_receiver_arguments.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class SocketReceiverArguments:
+    recv_host: str = field(
+        default="localhost",
+        metadata={
+            "help": "The host IP address for the socket connection. Default is 'localhost'. Use '0.0.0.0' to bind to all "
+            "available interfaces on the host machine."
+        },
+    )
+    recv_port: int = field(
+        default=12345,
+        metadata={
+            "help": "The port number on which the socket server listens. Default is 12345."
+        },
+    )
+    chunk_size: int = field(
+        default=1024,
+        metadata={
+            "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
+        },
+    )
diff --git a/arguments_classes/socket_sender_arguments.py b/arguments_classes/socket_sender_arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..8777f6cff6d2edef6e323dfe1db76e884af53ac5
--- /dev/null
+++ b/arguments_classes/socket_sender_arguments.py
@@ -0,0 +1,18 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class SocketSenderArguments:
+    send_host: str = field(
+        default="localhost",
+        metadata={
+            "help": "The host IP address for the socket connection. Default is 'localhost'. Use '0.0.0.0' to bind to all "
+            "available interfaces on the host machine."
+        },
+    )
+    send_port: int = field(
+        default=12346,
+        metadata={
+            "help": "The port number on which the socket server listens. Default is 12346."
+        },
+    )
diff --git a/arguments_classes/vad_arguments.py b/arguments_classes/vad_arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dfb378819d21c76a374d9ab5d10e588e784be0e
--- /dev/null
+++ b/arguments_classes/vad_arguments.py
@@ -0,0 +1,41 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class VADHandlerArguments:
+    thresh: float = field(
+        default=0.3,
+        metadata={
+            "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
+        },
+    )
+    sample_rate: int = field(
+        default=16000,
+        metadata={
+            "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
+        },
+    )
+    min_silence_ms: int = field(
+        default=250,
+        metadata={
+            "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
+        },
+    )
+    min_speech_ms: int = field(
+        default=500,
+        metadata={
+            "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
+        },
+    )
+    max_speech_ms: float = field(
+        default=float("inf"),
+        metadata={
+            "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
+        },
+    )
+    speech_pad_ms: int = field(
+        default=250,
+        metadata={
+            "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
+        },
+    )
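The millisecond-denominated VAD settings become concrete only in combination with `sample_rate`. A small illustrative sketch of the conversion and of applying `speech_pad_ms` to a detected segment (not the pipeline's exact code):

```python
import numpy as np

sample_rate = 16000  # Hz, matching the default above
speech_pad_ms = 250


def ms_to_samples(ms: float) -> int:
    return int(ms * sample_rate / 1000)


pad = ms_to_samples(speech_pad_ms)  # 4000 samples at 16 kHz


def pad_segment(audio: np.ndarray) -> np.ndarray:
    """Add speech_pad_ms of silence before and after a detected segment."""
    silence = np.zeros(pad, dtype=audio.dtype)
    return np.concatenate([silence, audio, silence])


segment = np.random.randn(ms_to_samples(600)).astype(np.float32)  # a 600 ms chunk
print(pad_segment(segment).shape)  # (17600,) == 600 ms + 2 * 250 ms
```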
diff --git a/arguments_classes/whisper_stt_arguments.py b/arguments_classes/whisper_stt_arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..bed382dda754da36965b4d86e68a7f8b4d9c322c
--- /dev/null
+++ b/arguments_classes/whisper_stt_arguments.py
@@ -0,0 +1,59 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class WhisperSTTHandlerArguments:
+    stt_model_name: str = field(
+        default="distil-whisper/distil-large-v3",
+        metadata={
+            "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
+        },
+    )
+    stt_device: str = field(
+        default="cuda",
+        metadata={
+            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
+        },
+    )
+    stt_torch_dtype: str = field(
+        default="float16",
+        metadata={
+            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
+        },
+    )
+    stt_compile_mode: str = field(
+        default=None,
+        metadata={
+            "help": "Compile mode for torch compile. One of 'default', 'reduce-overhead' or 'max-autotune'. Default is None (no compilation)."
+        },
+    )
+    stt_gen_max_new_tokens: int = field(
+        default=128,
+        metadata={
+            "help": "The maximum number of new tokens to generate. Default is 128."
+        },
+    )
+    stt_gen_num_beams: int = field(
+        default=1,
+        metadata={
+            "help": "The number of beams for beam search. Default is 1, implying greedy decoding."
+        },
+    )
+    stt_gen_return_timestamps: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to return timestamps with transcriptions. Default is False."
+        },
+    )
+    stt_gen_task: str = field(
+        default="transcribe",
+        metadata={
+            "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
+        },
+    )
+    stt_gen_language: str = field(
+        default="en",
+        metadata={
+            "help": "The language of the speech to transcribe. Default is 'en' for English."
+        },
+    )
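By convention, fields prefixed `stt_gen_` appear to map onto `generate()` kwargs once the prefix is stripped. The helper below is a hypothetical sketch of that convention; the real renaming logic is not part of this diff:

```python
# Hypothetical sketch of the prefix convention: fields named <prefix>_gen_*
# become generate() kwargs. Helper name and behavior are assumptions.
def rename_args(kwargs: dict, prefix: str) -> dict:
    gen_kwargs = {}
    for key, value in kwargs.items():
        if key.startswith(f"{prefix}_gen_"):
            gen_kwargs[key[len(prefix) + 5 :]] = value  # strip "<prefix>_gen_"
    return gen_kwargs


stt_args = {"stt_gen_max_new_tokens": 128, "stt_gen_num_beams": 1}
print(rename_args(stt_args, "stt"))  # {'max_new_tokens': 128, 'num_beams': 1}
```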
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 0299189cd8d8f317366aa2aa9a5f229c15e686a3..093bfb8c97ccbe4729b00ea8ca8fdbc6244f5559 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -4,7 +4,6 @@ import socket
 import sys
 import threading
 from copy import copy
-from dataclasses import dataclass, field
 from pathlib import Path
 from queue import Queue
 from threading import Event, Thread
@@ -12,9 +11,16 @@ from time import perf_counter
 from typing import Optional
 from sys import platform
 from LLM.mlx_lm import MLXLanguageModelHandler
+from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
+from arguments_classes.module_arguments import ModuleArguments
+from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
+from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
+from arguments_classes.socket_sender_arguments import SocketSenderArguments
+from arguments_classes.vad_arguments import VADHandlerArguments
+from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
 from baseHandler import BaseHandler
 from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
-from handlers.melo_tts_handler import MeloTTSHandlerArguments
+from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
 import numpy as np
 import torch
 import nltk
@@ -37,9 +43,9 @@ from utils import VADIterator, int2float, next_power_of_2
 
 # Ensure that the necessary NLTK resources are available
 try:
-    nltk.data.find('tokenizers/punkt_tab')
+    nltk.data.find("tokenizers/punkt_tab")
 except (LookupError, OSError):
-    nltk.download('punkt_tab')
+    nltk.download("punkt_tab")
 
 # caching allows ~50% compilation time reduction
 # see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
@@ -50,50 +56,6 @@ os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
 
 console = Console()
 
-@dataclass
-class ModuleArguments:
-    device: Optional[str] = field(
-        default=None,
-        metadata={"help": "If specified, overrides the device for all handlers."},
-    )
-    mode: Optional[str] = field(
-        default="socket",
-        metadata={
-            "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
-        },
-    )
-    local_mac_optimal_settings: bool = field(
-        default=False,
-        metadata={
-            "help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used."
-        },
-    )
-    stt: Optional[str] = field(
-        default="whisper",
-        metadata={
-            "help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'."
-        },
-    )
-    llm: Optional[str] = field(
-        default="transformers",
-        metadata={
-            "help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'"
-        },
-    )
-    tts: Optional[str] = field(
-        default="parler",
-        metadata={
-            "help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'"
-        },
-    )
-    log_level: str = field(
-        default="info",
-        metadata={
-            "help": "Provide logging level. Example --log_level debug, default=warning."
-        },
-    )
-
-
 class ThreadManager:
     """
     Manages multiple threads used to execute given handler tasks.
@@ -116,29 +78,6 @@ class ThreadManager:
         thread.join()
 
 
-@dataclass
-class SocketReceiverArguments:
-    recv_host: str = field(
-        default="localhost",
-        metadata={
-            "help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all "
-            "available interfaces on the host machine."
-        },
-    )
-    recv_port: int = field(
-        default=12345,
-        metadata={
-            "help": "The port number on which the socket server listens. Default is 12346."
-        },
-    )
-    chunk_size: int = field(
-        default=1024,
-        metadata={
-            "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
-        },
-    )
-
-
 class SocketReceiver:
     """
     Handles reception of the audio packets from the client.
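For orientation, the `SocketReceiver` these arguments configure reduces to a bind/accept/recv loop over `chunk_size`-byte chunks. A minimal sketch of how `recv_host`, `recv_port`, and `chunk_size` are used, not the class's actual code:

```python
import socket


def receive_audio(recv_host="localhost", recv_port=12345, chunk_size=1024):
    """Minimal sketch: accept one client and yield fixed-size audio chunks."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((recv_host, recv_port))
        server.listen(1)
        conn, _ = server.accept()
        with conn:
            while True:
                data = conn.recv(chunk_size)
                if not data:  # client closed the connection
                    break
                yield data
```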
@@ -192,23 +131,6 @@ class SocketReceiver:
         logger.info("Receiver closed")
 
 
-@dataclass
-class SocketSenderArguments:
-    send_host: str = field(
-        default="localhost",
-        metadata={
-            "help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all "
-            "available interfaces on the host machine."
-        },
-    )
-    send_port: int = field(
-        default=12346,
-        metadata={
-            "help": "The port number on which the socket server listens. Default is 12346."
-        },
-    )
-
-
 class SocketSender:
     """
     Handles sending generated audio packets to the clients.
@@ -238,46 +160,6 @@ class SocketSender:
         logger.info("Sender closed")
 
 
-@dataclass
-class VADHandlerArguments:
-    thresh: float = field(
-        default=0.3,
-        metadata={
-            "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
-        },
-    )
-    sample_rate: int = field(
-        default=16000,
-        metadata={
-            "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
-        },
-    )
-    min_silence_ms: int = field(
-        default=250,
-        metadata={
-            "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
-        },
-    )
-    min_speech_ms: int = field(
-        default=500,
-        metadata={
-            "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
-        },
-    )
-    max_speech_ms: float = field(
-        default=float("inf"),
-        metadata={
-            "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
-        },
-    )
-    speech_pad_ms: int = field(
-        default=250,
-        metadata={
-            "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
-        },
-    )
-
-
 class VADHandler(BaseHandler):
     """
     Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed
@@ -326,64 +208,6 @@ class VADHandler(BaseHandler):
         yield array
 
 
-@dataclass
-class WhisperSTTHandlerArguments:
-    stt_model_name: str = field(
-        default="distil-whisper/distil-large-v3",
-        metadata={
-            "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
-        },
-    )
-    stt_device: str = field(
-        default="cuda",
-        metadata={
-            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
-        },
-    )
-    stt_torch_dtype: str = field(
-        default="float16",
-        metadata={
-            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
-        },
-    )
-    stt_compile_mode: str = field(
-        default=None,
-        metadata={
-            "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
-        },
-    )
-    stt_gen_max_new_tokens: int = field(
-        default=128,
-        metadata={
-            "help": "The maximum number of new tokens to generate. Default is 128."
-        },
-    )
-    stt_gen_num_beams: int = field(
-        default=1,
-        metadata={
-            "help": "The number of beams for beam search. Default is 1, implying greedy decoding."
-        },
-    )
-    stt_gen_return_timestamps: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to return timestamps with transcriptions. Default is False."
-        },
-    )
-    # stt_gen_task: str = field(
-    #     default="transcribe",
-    #     metadata={
-    #         "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
-    #     },
-    # )
-    # stt_gen_language: str = field(
-    #     default="en",
-    #     metadata={
-    #         "help": "The language of the speech to transcribe. Default is 'en' for English."
-    #     },
-    # )
-
-
 class WhisperSTTHandler(BaseHandler):
     """
     Handles the Speech To Text generation using a Whisper model.
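The VAD path leans on small audio utilities from `utils.py`, e.g. `int2float` for turning int16 PCM into floats before running the Silero model. Its exact body isn't shown in this diff, so the following is an assumed but conventional implementation:

```python
import numpy as np


def int2float(sound: np.ndarray) -> np.ndarray:
    # Assumed to match utils.int2float: scale int16 PCM into [-1, 1] float32.
    sound = sound.astype(np.float32)
    if np.abs(sound).max() > 0:
        sound /= 32768.0
    return sound


pcm = np.array([0, 16384, -32768], dtype=np.int16)
print(int2float(pcm))  # [ 0.   0.5 -1. ]
```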
@@ -480,70 +304,6 @@ class WhisperSTTHandler(BaseHandler):
         yield pred_text
 
 
-@dataclass
-class LanguageModelHandlerArguments:
-    lm_model_name: str = field(
-        default="HuggingFaceTB/SmolLM-360M-Instruct",
-        metadata={
-            "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
-        },
-    )
-    lm_device: str = field(
-        default="cuda",
-        metadata={
-            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
-        },
-    )
-    lm_torch_dtype: str = field(
-        default="float16",
-        metadata={
-            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
-        },
-    )
-    user_role: str = field(
-        default="user",
-        metadata={
-            "help": "Role assigned to the user in the chat context. Default is 'user'."
-        },
-    )
-    init_chat_role: str = field(
-        default='system',
-        metadata={
-            "help": "Initial role for setting up the chat context. Default is 'system'."
-        },
-    )
-    init_chat_prompt: str = field(
-        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
-        metadata={
-            "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
-        },
-    )
-    lm_gen_max_new_tokens: int = field(
-        default=128,
-        metadata={
-            "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
-        },
-    )
-    lm_gen_temperature: float = field(
-        default=0.0,
-        metadata={
-            "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
-        },
-    )
-    lm_gen_do_sample: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
-        },
-    )
-    chat_size: int = field(
-        default=2,
-        metadata={
-            "help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
-        },
-    )
-
-
 class Chat:
     """
     Handles the chat using to avoid OOM issues.
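`chat_size` bounds the history kept by the `Chat` class referenced above ("to avoid OOM issues"). A minimal sketch of such a bounded buffer, assuming one pinned init message plus the last `chat_size` user/assistant pairs; the real class lives in `LLM/chat.py` and may differ:

```python
class BoundedChat:
    """Keep an optional init message plus the last `size` user/assistant pairs."""

    def __init__(self, size: int = 2):
        self.size = size
        self.init_chat_message = None
        self.buffer = []

    def init_chat(self, init_chat_message: dict):
        self.init_chat_message = init_chat_message

    def append(self, item: dict):
        self.buffer.append(item)
        if len(self.buffer) == 2 * (self.size + 1):
            self.buffer.pop(0)  # drop the oldest user message...
            self.buffer.pop(0)  # ...and its assistant reply

    def to_list(self) -> list:
        if self.init_chat_message:
            return [self.init_chat_message] + self.buffer
        return self.buffer
```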
@@ -684,67 +444,6 @@ class LanguageModelHandler(BaseHandler):
         yield printable_text
 
 
-@dataclass
-class ParlerTTSHandlerArguments:
-    tts_model_name: str = field(
-        default="ylacombe/parler-tts-mini-jenny-30H",
-        metadata={
-            "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
-        },
-    )
-    tts_device: str = field(
-        default="cuda",
-        metadata={
-            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
-        },
-    )
-    tts_torch_dtype: str = field(
-        default="float16",
-        metadata={
-            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
-        },
-    )
-    tts_compile_mode: str = field(
-        default=None,
-        metadata={
-            "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
-        },
-    )
-    tts_gen_min_new_tokens: int = field(
-        default=64,
-        metadata={
-            "help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"
-        },
-    )
-    tts_gen_max_new_tokens: int = field(
-        default=512,
-        metadata={
-            "help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
-        },
-    )
-    description: str = field(
-        default=(
-            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
-            "She speaks very fast."
-        ),
-        metadata={
-            "help": "Description of the speaker's voice and speaking style to guide the TTS model."
-        },
-    )
-    play_steps_s: float = field(
-        default=1.0,
-        metadata={
-            "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
-        },
-    )
-    max_prompt_pad_length: int = field(
-        default=8,
-        metadata={
-            "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
-        },
-    )
-
-
 class ParlerTTSHandler(BaseHandler):
     def setup(
         self,
@@ -886,7 +585,7 @@ class ParlerTTSHandler(BaseHandler):
         thread.start()
 
         for i, audio_chunk in enumerate(streamer):
-            if i == 0 and 'pipeline_start' in globals():
+            if i == 0 and "pipeline_start" in globals():
                 logger.info(
                     f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
                 )
@@ -971,7 +670,6 @@ def main():
     if module_kwargs.log_level == "debug":
         torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
 
-
     def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
         if mac_optimal_settings:
             for kwargs in handler_kwargs:
@@ -1070,14 +768,14 @@ def main():
         setup_args=(should_listen,),
         setup_kwargs=vars(vad_handler_kwargs),
     )
-    if module_kwargs.stt == 'whisper':
+    if module_kwargs.stt == "whisper":
         stt = WhisperSTTHandler(
-                stop_event,
-                queue_in=spoken_prompt_queue,
-                queue_out=text_prompt_queue,
-                setup_kwargs=vars(whisper_stt_handler_kwargs),
-        )
-    elif module_kwargs.stt == 'whisper-mlx':
+            stop_event,
+            queue_in=spoken_prompt_queue,
+            queue_out=text_prompt_queue,
+            setup_kwargs=vars(whisper_stt_handler_kwargs),
+        )
+    elif module_kwargs.stt == "whisper-mlx":
         stt = LightningWhisperSTTHandler(
             stop_event,
             queue_in=spoken_prompt_queue,
@@ -1086,14 +784,14 @@ def main():
         )
     else:
         raise ValueError("The STT should be either whisper or whisper-mlx")
-    if module_kwargs.llm == 'transformers':
+    if module_kwargs.llm == "transformers":
         lm = LanguageModelHandler(
-                stop_event,
-                queue_in=text_prompt_queue,
-                queue_out=lm_response_queue,
-                setup_kwargs=vars(language_model_handler_kwargs),
-        )
-    elif module_kwargs.llm == 'mlx-lm':
+            stop_event,
+            queue_in=text_prompt_queue,
+            queue_out=lm_response_queue,
+            setup_kwargs=vars(language_model_handler_kwargs),
+        )
+    elif module_kwargs.llm == "mlx-lm":
         lm = MLXLanguageModelHandler(
             stop_event,
             queue_in=text_prompt_queue,
@@ -1102,7 +800,7 @@ def main():
         )
     else:
         raise ValueError("The LLM should be either transformers or mlx-lm")
-    if module_kwargs.tts == 'parler':
+    if module_kwargs.tts == "parler":
         torch._inductor.config.fx_graph_cache = True
         # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
         torch._dynamo.config.cache_size_limit = 15
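All of these handlers share one shape: consume `queue_in`, produce `queue_out`, stop on a shared `Event`, with `ThreadManager` running one thread per handler. A toy stand-in for that pattern (the stage function and sentinel handling are illustrative, not the pipeline's code):

```python
from queue import Queue
from threading import Event, Thread


def run_stage(stop_event: Event, queue_in: Queue, queue_out: Queue, fn):
    """One pipeline stage: pull, transform, push, until stopped."""
    while not stop_event.is_set():
        item = queue_in.get()
        if item is None:  # sentinel: propagate shutdown downstream
            queue_out.put(None)
            break
        queue_out.put(fn(item))


stop_event = Event()
text_prompt_queue, lm_response_queue = Queue(), Queue()

t = Thread(
    target=run_stage,
    args=(stop_event, text_prompt_queue, lm_response_queue, str.upper),
)
t.start()
text_prompt_queue.put("hello")
text_prompt_queue.put(None)
print(lm_response_queue.get())  # HELLO
t.join()
```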
@@ -1113,12 +811,14 @@ def main():
         setup_args=(should_listen,),
         setup_kwargs=vars(parler_tts_handler_kwargs),
     )
-
-    elif module_kwargs.tts == 'melo':
+
+    elif module_kwargs.tts == "melo":
         try:
             from TTS.melotts import MeloTTSHandler
         except RuntimeError as e:
-            logger.error(f"Error importing MeloTTSHandler. You might need to run: python -m unidic download")
+            logger.error(
+                "Error importing MeloTTSHandler. You might need to run: python -m unidic download"
+            )
             raise e
         tts = MeloTTSHandler(
             stop_event,
diff --git a/utils.py b/utils.py
index f4237a1121e5a398e09bb8249bd37a9bffa81b43..3e2e9bc1d98b5e3b350a7fb865b0b7414af9d3aa 100644
--- a/utils.py
+++ b/utils.py
@@ -84,7 +84,7 @@ class VADIterator:
         if not torch.is_tensor(x):
             try:
                 x = torch.Tensor(x)
-            except:
+            except Exception:
                 raise TypeError("Audio cannot be casted to tensor. Cast it manually")
 
         window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
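The `utils.py` change is small but worth spelling out: a bare `except:` also traps `KeyboardInterrupt` and `SystemExit`, so narrowing it to `except Exception:` keeps Ctrl-C working during audio capture. The fixed pattern in isolation:

```python
import torch


def ensure_tensor(x):
    """Mirror of the narrowed handler in utils.py's VADIterator."""
    if not torch.is_tensor(x):
        try:
            x = torch.Tensor(x)
        except Exception:  # a bare `except:` would also swallow KeyboardInterrupt
            raise TypeError("Audio cannot be casted to tensor. Cast it manually")
    return x
```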