diff --git a/arguments_classes/parler_tts_arguments.py b/arguments_classes/parler_tts_arguments.py
index 515943209d2068bc0e2a451eb08d1cce68cfb67d..977e30ccfc0b3a15b269edb3fc827b2ffb50039f 100644
--- a/arguments_classes/parler_tts_arguments.py
+++ b/arguments_classes/parler_tts_arguments.py
@@ -3,43 +3,43 @@ from dataclasses import dataclass, field
 
 @dataclass
 class ParlerTTSHandlerArguments:
-    tts_model_name: str = field(
+    parler_model_name: str = field(
         default="ylacombe/parler-tts-mini-jenny-30H",
         metadata={
             "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
         },
     )
-    tts_device: str = field(
+    parler_device: str = field(
         default="cuda",
         metadata={
             "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
         },
     )
-    tts_torch_dtype: str = field(
+    parler_torch_dtype: str = field(
         default="float16",
         metadata={
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
         },
     )
-    tts_compile_mode: str = field(
+    parler_compile_mode: str = field(
         default=None,
         metadata={
             "help": "Compile mode for torch compile. One of 'default', 'reduce-overhead' or 'max-autotune'. Default is None (no compilation)."
         },
     )
-    tts_gen_min_new_tokens: int = field(
+    parler_gen_min_new_tokens: int = field(
         default=64,
         metadata={
             "help": "Minimum number of new tokens to generate in a single completion. Default is 64, which corresponds to ~0.6 secs"
         },
     )
-    tts_gen_max_new_tokens: int = field(
+    parler_gen_max_new_tokens: int = field(
         default=512,
         metadata={
             "help": "Maximum number of new tokens to generate in a single completion. Default is 512, which corresponds to ~12 secs"
         },
     )
-    description: str = field(
+    parler_description: str = field(
         default=(
             "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
             "She speaks very fast."
@@ -48,13 +48,13 @@ class ParlerTTSHandlerArguments:
             "help": "Description of the speaker's voice and speaking style to guide the TTS model."
         },
     )
-    play_steps_s: float = field(
+    parler_play_steps_s: float = field(
         default=1.0,
         metadata={
             "help": "The time interval in seconds for playing back the generated speech in steps. Default is 1.0 second."
         },
     )
-    max_prompt_pad_length: int = field(
+    parler_max_prompt_pad_length: int = field(
         default=8,
         metadata={
             "help": "When using compilation, the prompt has to be padded to the closest power of 2. This parameter sets the maximum power of 2 possible."
diff --git a/arguments_classes/language_model_arguments.py b/arguments_classes/transformers_language_model_arguments.py
similarity index 78%
rename from arguments_classes/language_model_arguments.py
rename to arguments_classes/transformers_language_model_arguments.py
index 8680a78fe07ec2d7cb5f6e1a3dcbb0ed3319ac6b..717219ead2f7e99e15c8521886eb1fa4a05af79d 100644
--- a/arguments_classes/language_model_arguments.py
+++ b/arguments_classes/transformers_language_model_arguments.py
@@ -2,68 +2,68 @@ from dataclasses import dataclass, field
 
 
 @dataclass
-class LanguageModelHandlerArguments:
-    lm_model_name: str = field(
+class TransformersLanguageModelHandlerArguments:
+    transformers_lm_model_name: str = field(
         default="HuggingFaceTB/SmolLM-360M-Instruct",
         metadata={
             "help": "The pretrained language model to use. Default is 'HuggingFaceTB/SmolLM-360M-Instruct'."
         },
     )
-    lm_device: str = field(
+    transformers_lm_device: str = field(
         default="cuda",
         metadata={
             "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
         },
     )
-    lm_torch_dtype: str = field(
+    transformers_lm_torch_dtype: str = field(
         default="float16",
         metadata={
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
         },
     )
-    user_role: str = field(
+    transformers_lm_user_role: str = field(
         default="user",
         metadata={
             "help": "Role assigned to the user in the chat context. Default is 'user'."
         },
     )
-    init_chat_role: str = field(
+    transformers_lm_init_chat_role: str = field(
         default="system",
         metadata={
             "help": "Initial role for setting up the chat context. Default is 'system'."
         },
     )
-    init_chat_prompt: str = field(
+    transformers_lm_init_chat_prompt: str = field(
         default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
         metadata={
             "help": "The initial chat prompt used to establish context for the language model."
         },
     )
-    lm_gen_max_new_tokens: int = field(
+    transformers_lm_gen_max_new_tokens: int = field(
         default=128,
         metadata={
             "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
         },
     )
-    lm_gen_min_new_tokens: int = field(
+    transformers_lm_gen_min_new_tokens: int = field(
         default=0,
         metadata={
             "help": "Minimum number of new tokens to generate in a single completion. Default is 0."
         },
     )
-    lm_gen_temperature: float = field(
+    transformers_lm_gen_temperature: float = field(
         default=0.0,
         metadata={
             "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
         },
     )
-    lm_gen_do_sample: bool = field(
+    transformers_lm_gen_do_sample: bool = field(
         default=False,
         metadata={
             "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
         },
     )
-    chat_size: int = field(
+    transformers_lm_chat_size: int = field(
         default=2,
         metadata={
             "help": "Number of assistant-user interactions to keep in the chat history. None for no limit."
diff --git a/arguments_classes/whisper_stt_arguments.py b/arguments_classes/whisper_stt_arguments.py
index 5dc700bf24e2320d0065ab6db40c0adbcf4782b5..cf06ca4a53fd198b03d407a2c6fdbf479233c23e 100644
--- a/arguments_classes/whisper_stt_arguments.py
+++ b/arguments_classes/whisper_stt_arguments.py
@@ -4,49 +4,49 @@ from typing import Optional
 
 
 @dataclass
 class WhisperSTTHandlerArguments:
-    stt_model_name: str = field(
+    whisper_model_name: str = field(
         default="distil-whisper/distil-large-v3",
         metadata={
             "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
         },
     )
-    stt_device: str = field(
+    whisper_device: str = field(
         default="cuda",
         metadata={
             "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
         },
     )
-    stt_torch_dtype: str = field(
+    whisper_torch_dtype: str = field(
         default="float16",
         metadata={
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
         },
     )
-    stt_compile_mode: str = field(
+    whisper_compile_mode: str = field(
         default=None,
         metadata={
             "help": "Compile mode for torch compile. One of 'default', 'reduce-overhead' or 'max-autotune'. Default is None (no compilation)."
         },
     )
-    stt_gen_max_new_tokens: int = field(
+    whisper_gen_max_new_tokens: int = field(
         default=128,
         metadata={
             "help": "The maximum number of new tokens to generate. Default is 128."
         },
     )
-    stt_gen_num_beams: int = field(
+    whisper_gen_num_beams: int = field(
         default=1,
         metadata={
             "help": "The number of beams for beam search. Default is 1, implying greedy decoding."
         },
     )
-    stt_gen_return_timestamps: bool = field(
+    whisper_gen_return_timestamps: bool = field(
         default=False,
         metadata={
             "help": "Whether to return timestamps with transcriptions. Default is False."
         },
     )
-    stt_gen_task: str = field(
+    whisper_gen_task: str = field(
         default="transcribe",
         metadata={
             "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 1da202e200c825ad8473f1e115d41cd5f8f686ff..d40880e3a6f72b9ce743d52b98ebf43995f2d434 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -9,7 +9,7 @@ from typing import Optional
 from sys import platform
 from VAD.vad_handler import VADHandler
 from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments
-from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
+from arguments_classes.transformers_language_model_arguments import TransformersLanguageModelHandlerArguments
 from arguments_classes.mlx_language_model_arguments import (
     MLXLanguageModelHandlerArguments,
 )
@@ -76,7 +76,7 @@ def parse_arguments():
             VADHandlerArguments,
             WhisperSTTHandlerArguments,
             ParaformerSTTHandlerArguments,
-            LanguageModelHandlerArguments,
+            TransformersLanguageModelHandlerArguments,
             OpenApiLanguageModelHandlerArguments,
             MLXLanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
@@ -161,7 +161,7 @@ def prepare_all_args(
     module_kwargs,
     whisper_stt_handler_kwargs,
     paraformer_stt_handler_kwargs,
-    language_model_handler_kwargs,
+    transformers_language_model_handler_kwargs,
     open_api_language_model_handler_kwargs,
     mlx_language_model_handler_kwargs,
     parler_tts_handler_kwargs,
@@ -172,7 +172,7 @@ def prepare_all_args(
         module_kwargs,
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
-        language_model_handler_kwargs,
+        transformers_language_model_handler_kwargs,
         open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
@@ -181,12 +181,12 @@ def prepare_all_args(
     )
 
-    rename_args(whisper_stt_handler_kwargs, "stt")
+    rename_args(whisper_stt_handler_kwargs, "whisper")
     rename_args(paraformer_stt_handler_kwargs, "paraformer_stt")
-    rename_args(language_model_handler_kwargs, "lm")
+    rename_args(transformers_language_model_handler_kwargs, "transformers_lm")
     rename_args(mlx_language_model_handler_kwargs, "mlx_lm")
     rename_args(open_api_language_model_handler_kwargs, "open_api")
-    rename_args(parler_tts_handler_kwargs, "tts")
+    rename_args(parler_tts_handler_kwargs, "parler")
     rename_args(melo_tts_handler_kwargs, "melo")
     rename_args(chat_tts_handler_kwargs, "chat_tts")
@@ -210,7 +210,7 @@ def build_pipeline(
     vad_handler_kwargs,
     whisper_stt_handler_kwargs,
     paraformer_stt_handler_kwargs,
-    language_model_handler_kwargs,
+    transformers_language_model_handler_kwargs,
     open_api_language_model_handler_kwargs,
     mlx_language_model_handler_kwargs,
     parler_tts_handler_kwargs,
@@ -262,14 +262,45 @@ def build_pipeline(
         setup_kwargs=vars(vad_handler_kwargs),
     )
 
-    stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
-    lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs)
-    tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs)
+    stt = get_stt_handler(
+        module_kwargs,
+        stop_event,
+        spoken_prompt_queue,
+        text_prompt_queue,
+        whisper_stt_handler_kwargs,
+        paraformer_stt_handler_kwargs
+    )
+    lm = get_llm_handler(
+        module_kwargs,
+        stop_event,
+        text_prompt_queue,
+        lm_response_queue,
+        transformers_language_model_handler_kwargs,
+        open_api_language_model_handler_kwargs,
+        mlx_language_model_handler_kwargs
+    )
+    tts = get_tts_handler(
+        module_kwargs,
+        stop_event,
+        lm_response_queue,
+        send_audio_chunks_queue,
+        should_listen,
+        parler_tts_handler_kwargs,
+        melo_tts_handler_kwargs,
+        chat_tts_handler_kwargs
+    )
 
     return ThreadManager([*comms_handlers, vad, stt, lm, tts])
 
 
-def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs):
+def get_stt_handler(
+    module_kwargs,
+    stop_event,
+    spoken_prompt_queue,
+    text_prompt_queue,
+    whisper_stt_handler_kwargs,
+    paraformer_stt_handler_kwargs
+):
     if module_kwargs.stt == "whisper":
         from STT.whisper_stt_handler import WhisperSTTHandler
         return WhisperSTTHandler(
@@ -337,7 +368,16 @@ def get_llm_handler(
         raise ValueError("The LLM should be either transformers or mlx-lm")
 
 
-def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs):
+def get_tts_handler(
+    module_kwargs,
+    stop_event,
+    lm_response_queue,
+    send_audio_chunks_queue,
+    should_listen,
+    parler_tts_handler_kwargs,
+    melo_tts_handler_kwargs,
+    chat_tts_handler_kwargs
+):
     if module_kwargs.tts == "parler":
         from TTS.parler_handler import ParlerTTSHandler
         return ParlerTTSHandler(
@@ -387,7 +427,7 @@ def main():
         vad_handler_kwargs,
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
-        language_model_handler_kwargs,
+        transformers_language_model_handler_kwargs,
         open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
@@ -401,14 +441,14 @@ def main():
         module_kwargs,
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
-        language_model_handler_kwargs,
+        transformers_language_model_handler_kwargs,
         open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
         chat_tts_handler_kwargs,
     )
-    
+
     queues_and_events = initialize_queues_and_events()
 
     pipeline_manager = build_pipeline(
@@ -418,7 +458,7 @@ def main():
         vad_handler_kwargs,
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
-        language_model_handler_kwargs,
+        transformers_language_model_handler_kwargs,
         open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
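Note on the renaming convention this patch relies on: every handler flag now carries its handler's prefix (e.g. `--whisper_model_name` instead of `--stt_model_name`), and `prepare_all_args()` strips that prefix via `rename_args()` before the kwargs reach the handler, which is why each `rename_args` call site is updated in lockstep with the dataclass fields. `rename_args()` itself is not part of this diff, so the sketch below is an assumed, minimal illustration of that contract, not the repository's actual implementation:

```python
import argparse

def rename_args(args: argparse.Namespace, prefix: str) -> None:
    # Hypothetical sketch (the real rename_args is not shown in this diff):
    # strip "<prefix>_" from every matching attribute, in place, so that
    # e.g. whisper_model_name becomes the generic model_name the handler expects.
    for name in list(vars(args)):
        if name.startswith(prefix + "_"):
            setattr(args, name[len(prefix) + 1:], getattr(args, name))
            delattr(args, name)

# Example: the renamed Whisper flags collapse back to generic handler kwargs.
ns = argparse.Namespace(
    whisper_model_name="distil-whisper/distil-large-v3",
    whisper_device="cuda",
)
rename_args(ns, "whisper")
assert ns.model_name == "distil-whisper/distil-large-v3" and ns.device == "cuda"
```

Under this assumption, the renames are purely cosmetic at the handler level: only the CLI surface changes (e.g. `--tts_model_name` becomes `--parler_model_name`), while each handler still receives the same generic keyword arguments as before.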