From c67aec05d25022f770b970ade682b7a4d04fff67 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan <eulebihan@gmail.com>
Date: Mon, 23 Sep 2024 14:28:58 +0200
Subject: [PATCH] Revert breaking argument renames (restore generic lm_/stt_/tts_ names)

---
 ...guments.py => language_model_arguments.py} | 24 +++---
 arguments_classes/parler_tts_arguments.py     | 18 ++---
 arguments_classes/whisper_stt_arguments.py    | 16 ++--
 s2s_pipeline.py                               | 74 +++++--------------
 4 files changed, 46 insertions(+), 86 deletions(-)
 rename arguments_classes/{transformers_language_model_arguments.py => language_model_arguments.py} (78%)

diff --git a/arguments_classes/transformers_language_model_arguments.py b/arguments_classes/language_model_arguments.py
similarity index 78%
rename from arguments_classes/transformers_language_model_arguments.py
rename to arguments_classes/language_model_arguments.py
index 717219e..8680a78 100644
--- a/arguments_classes/transformers_language_model_arguments.py
+++ b/arguments_classes/language_model_arguments.py
@@ -2,68 +2,68 @@ from dataclasses import dataclass, field
 
 
 @dataclass
-class TransformersLanguageModelHandlerArguments:
-    transformers_lm_model_name: str = field(
+class LanguageModelHandlerArguments:
+    lm_model_name: str = field(
         default="HuggingFaceTB/SmolLM-360M-Instruct",
         metadata={
             "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
         },
     )
-    transformers_lm_device: str = field(
+    lm_device: str = field(
         default="cuda",
         metadata={
             "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
         },
     )
-    transformers_lm_torch_dtype: str = field(
+    lm_torch_dtype: str = field(
         default="float16",
         metadata={
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
         },
     )
-    transformers_lm_user_role: str = field(
+    user_role: str = field(
         default="user",
         metadata={
             "help": "Role assigned to the user in the chat context. Default is 'user'."
         },
     )
-    transformers_lm_init_chat_role: str = field(
+    init_chat_role: str = field(
         default="system",
         metadata={
             "help": "Initial role for setting up the chat context. Default is 'system'."
         },
     )
-    transformers_lm_init_chat_prompt: str = field(
+    init_chat_prompt: str = field(
         default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
         metadata={
             "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
         },
     )
-    transformers_lm_gen_max_new_tokens: int = field(
+    lm_gen_max_new_tokens: int = field(
         default=128,
         metadata={
             "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
         },
     )
-    transformers_lm_gen_min_new_tokens: int = field(
+    lm_gen_min_new_tokens: int = field(
         default=0,
         metadata={
             "help": "Minimum number of new tokens to generate in a single completion. Default is 0."
         },
     )
-    transformers_lm_gen_temperature: float = field(
+    lm_gen_temperature: float = field(
         default=0.0,
         metadata={
             "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
         },
     )
-    transformers_lm_gen_do_sample: bool = field(
+    lm_gen_do_sample: bool = field(
         default=False,
         metadata={
             "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
         },
     )
-    transformers_lm_chat_size: int = field(
+    chat_size: int = field(
         default=2,
         metadata={
             "help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
diff --git a/arguments_classes/parler_tts_arguments.py b/arguments_classes/parler_tts_arguments.py
index 977e30c..5159432 100644
--- a/arguments_classes/parler_tts_arguments.py
+++ b/arguments_classes/parler_tts_arguments.py
@@ -3,43 +3,43 @@ from dataclasses import dataclass, field
 
 @dataclass
 class ParlerTTSHandlerArguments:
-    parler_model_name: str = field(
+    tts_model_name: str = field(
         default="ylacombe/parler-tts-mini-jenny-30H",
         metadata={
             "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
         },
     )
-    parler_device: str = field(
+    tts_device: str = field(
         default="cuda",
         metadata={
             "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
         },
     )
-    parler_torch_dtype: str = field(
+    tts_torch_dtype: str = field(
         default="float16",
         metadata={
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
         },
     )
-    parler_compile_mode: str = field(
+    tts_compile_mode: str = field(
         default=None,
         metadata={
             "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
         },
     )
-    parler_gen_min_new_tokens: int = field(
+    tts_gen_min_new_tokens: int = field(
         default=64,
         metadata={
             "help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"
         },
     )
-    parler_gen_max_new_tokens: int = field(
+    tts_gen_max_new_tokens: int = field(
         default=512,
         metadata={
             "help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
         },
     )
-    parler_description: str = field(
+    description: str = field(
         default=(
             "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
             "She speaks very fast."
@@ -48,13 +48,13 @@ class ParlerTTSHandlerArguments:
             "help": "Description of the speaker's voice and speaking style to guide the TTS model."
         },
     )
-    parler_play_steps_s: float = field(
+    play_steps_s: float = field(
         default=1.0,
         metadata={
             "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
         },
     )
-    parler_max_prompt_pad_length: int = field(
+    max_prompt_pad_length: int = field(
         default=8,
         metadata={
             "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
diff --git a/arguments_classes/whisper_stt_arguments.py b/arguments_classes/whisper_stt_arguments.py
index cf06ca4..5dc700b 100644
--- a/arguments_classes/whisper_stt_arguments.py
+++ b/arguments_classes/whisper_stt_arguments.py
@@ -4,49 +4,49 @@ from typing import Optional
 
 @dataclass
 class WhisperSTTHandlerArguments:
-    whisper_model_name: str = field(
+    stt_model_name: str = field(
         default="distil-whisper/distil-large-v3",
         metadata={
             "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
         },
     )
-    whisper_device: str = field(
+    stt_device: str = field(
         default="cuda",
         metadata={
             "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
         },
     )
-    whisper_torch_dtype: str = field(
+    stt_torch_dtype: str = field(
         default="float16",
         metadata={
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
         },
     )
-    whisper_compile_mode: str = field(
+    stt_compile_mode: str = field(
         default=None,
         metadata={
             "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
         },
     )
-    whisper_gen_max_new_tokens: int = field(
+    stt_gen_max_new_tokens: int = field(
         default=128,
         metadata={
             "help": "The maximum number of new tokens to generate. Default is 128."
         },
     )
-    whisper_gen_num_beams: int = field(
+    stt_gen_num_beams: int = field(
         default=1,
         metadata={
             "help": "The number of beams for beam search. Default is 1, implying greedy decoding."
         },
     )
-    whisper_gen_return_timestamps: bool = field(
+    stt_gen_return_timestamps: bool = field(
         default=False,
         metadata={
             "help": "Whether to return timestamps with transcriptions. Default is False."
         },
     )
-    whisper_gen_task: str = field(
+    stt_gen_task: str = field(
         default="transcribe",
         metadata={
             "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index d40880e..1da202e 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -9,7 +9,7 @@ from typing import Optional
 from sys import platform
 from VAD.vad_handler import VADHandler
 from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments
-from arguments_classes.transformers_language_model_arguments import TransformersLanguageModelHandlerArguments
+from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
 from arguments_classes.mlx_language_model_arguments import (
     MLXLanguageModelHandlerArguments,
 )
@@ -76,7 +76,7 @@ def parse_arguments():
             VADHandlerArguments,
             WhisperSTTHandlerArguments,
             ParaformerSTTHandlerArguments,
-            TransformersLanguageModelHandlerArguments,
+            LanguageModelHandlerArguments,
             OpenApiLanguageModelHandlerArguments,
             MLXLanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
@@ -161,7 +161,7 @@ def prepare_all_args(
     module_kwargs,
     whisper_stt_handler_kwargs,
     paraformer_stt_handler_kwargs,
-    transformers_language_model_handler_kwargs,
+    language_model_handler_kwargs,
     open_api_language_model_handler_kwargs,
     mlx_language_model_handler_kwargs,
     parler_tts_handler_kwargs,
@@ -172,7 +172,7 @@ def prepare_all_args(
         module_kwargs,
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
-        transformers_language_model_handler_kwargs,
+        language_model_handler_kwargs,
         open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
@@ -181,12 +181,12 @@ def prepare_all_args(
     )
 
 
-    rename_args(whisper_stt_handler_kwargs, "whisper")
+    rename_args(whisper_stt_handler_kwargs, "stt")
     rename_args(paraformer_stt_handler_kwargs, "paraformer_stt")
-    rename_args(transformers_language_model_handler_kwargs, "transformers_lm")
+    rename_args(language_model_handler_kwargs, "lm")
     rename_args(mlx_language_model_handler_kwargs, "mlx_lm")
     rename_args(open_api_language_model_handler_kwargs, "open_api")
-    rename_args(parler_tts_handler_kwargs, "parler")
+    rename_args(parler_tts_handler_kwargs, "tts")
     rename_args(melo_tts_handler_kwargs, "melo")
     rename_args(chat_tts_handler_kwargs, "chat_tts")
 
@@ -210,7 +210,7 @@ def build_pipeline(
     vad_handler_kwargs,
     whisper_stt_handler_kwargs,
     paraformer_stt_handler_kwargs,
-    transformers_language_model_handler_kwargs,
+    language_model_handler_kwargs,
     open_api_language_model_handler_kwargs,
     mlx_language_model_handler_kwargs,
     parler_tts_handler_kwargs,
@@ -262,45 +262,14 @@ def build_pipeline(
         setup_kwargs=vars(vad_handler_kwargs),
     )
 
-    stt = get_stt_handler(
-        module_kwargs, 
-        stop_event, 
-        spoken_prompt_queue, 
-        text_prompt_queue, 
-        whisper_stt_handler_kwargs, 
-        paraformer_stt_handler_kwargs
-    )
-    lm = get_llm_handler(
-        module_kwargs, 
-        stop_event, 
-        text_prompt_queue, 
-        lm_response_queue, 
-        transformers_language_model_handler_kwargs, 
-        open_api_language_model_handler_kwargs, 
-        mlx_language_model_handler_kwargs
-    )
-    tts = get_tts_handler(
-        module_kwargs, 
-        stop_event, 
-        lm_response_queue, 
-        send_audio_chunks_queue, 
-        should_listen, 
-        parler_tts_handler_kwargs, 
-        melo_tts_handler_kwargs, 
-        chat_tts_handler_kwargs
-    )
+    stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
+    lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs)
+    tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs)
 
     return ThreadManager([*comms_handlers, vad, stt, lm, tts])
 
 
-def get_stt_handler(
-    module_kwargs, 
-    stop_event, 
-    spoken_prompt_queue, 
-    text_prompt_queue, 
-    whisper_stt_handler_kwargs, 
-    paraformer_stt_handler_kwargs
-):
+def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs):
     if module_kwargs.stt == "whisper":
         from STT.whisper_stt_handler import WhisperSTTHandler
         return WhisperSTTHandler(
@@ -368,16 +337,7 @@ def get_llm_handler(
         raise ValueError("The LLM should be either transformers or mlx-lm")
 
 
-def get_tts_handler(
-    module_kwargs, 
-    stop_event, 
-    lm_response_queue, 
-    send_audio_chunks_queue, 
-    should_listen, 
-    parler_tts_handler_kwargs, 
-    melo_tts_handler_kwargs, 
-    chat_tts_handler_kwargs
-):
+def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs):
     if module_kwargs.tts == "parler":
         from TTS.parler_handler import ParlerTTSHandler
         return ParlerTTSHandler(
@@ -427,7 +387,7 @@ def main():
         vad_handler_kwargs,
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
-        transformers_language_model_handler_kwargs,
+        language_model_handler_kwargs,
         open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
@@ -441,14 +401,14 @@ def main():
         module_kwargs,
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
-        transformers_language_model_handler_kwargs,
+        language_model_handler_kwargs,
         open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
         chat_tts_handler_kwargs,
     )
-    
+
     queues_and_events = initialize_queues_and_events()
 
     pipeline_manager = build_pipeline(
@@ -458,7 +418,7 @@ def main():
         vad_handler_kwargs,
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
-        transformers_language_model_handler_kwargs,
+        language_model_handler_kwargs,
         open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
-- 
GitLab