From 7c8d24871924968a0026461fc7f205c32842cac0 Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Fri, 23 Aug 2024 14:13:18 +0200
Subject: [PATCH] Improve MLX pipeline

---
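Note: the new MLXLanguageModelHandlerArguments class follows the same
HfArgumentParser convention as the existing handler argument classes,
so each field name doubles as a CLI flag, and the "mlx_lm" prefix
passed to prepare_args matches the field-name prefix. A minimal
standalone sketch of how the class is parsed (assuming the
transformers.HfArgumentParser already used by s2s_pipeline.py; the
real pipeline parses it together with the other argument classes):

    # Sketch only: parse the MLX LM arguments in isolation.
    from transformers import HfArgumentParser

    from arguments_classes.mlx_language_model_arguments import (
        MLXLanguageModelHandlerArguments,
    )

    parser = HfArgumentParser(MLXLanguageModelHandlerArguments)
    (mlx_lm_kwargs,) = parser.parse_args_into_dataclasses()
    # Field names map directly to flags, e.g.:
    #   --mlx_lm_model_name mlx-community/SmolLM-360M-Instruct
    #   --mlx_lm_gen_max_new_tokens 128
    print(mlx_lm_kwargs.mlx_lm_model_name)
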
 .../mlx_language_model_arguments.py           | 65 +++++++++++++++++++
 arguments_classes/vad_arguments.py            |  4 ++--
 s2s_pipeline.py                               |  8 ++-
 3 files changed, 74 insertions(+), 3 deletions(-)
 create mode 100644 arguments_classes/mlx_language_model_arguments.py

diff --git a/arguments_classes/mlx_language_model_arguments.py b/arguments_classes/mlx_language_model_arguments.py
new file mode 100644
index 0000000..0765ec9
--- /dev/null
+++ b/arguments_classes/mlx_language_model_arguments.py
@@ -0,0 +1,65 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class MLXLanguageModelHandlerArguments:
+    mlx_lm_model_name: str = field(
+        default="mlx-community/SmolLM-360M-Instruct",
+        metadata={
+            "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
+        },
+    )
+    mlx_lm_device: str = field(
+        default="mps",
+        metadata={
+            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
+        },
+    )
+    mlx_lm_torch_dtype: str = field(
+        default="float16",
+        metadata={
+            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
+        },
+    )
+    mlx_lm_user_role: str = field(
+        default="user",
+        metadata={
+            "help": "Role assigned to the user in the chat context. Default is 'user'."
+        },
+    )
+    mlx_lm_init_chat_role: str = field(
+        default="system",
+        metadata={
+            "help": "Initial role for setting up the chat context. Default is 'system'."
+        },
+    )
+    mlx_lm_init_chat_prompt: str = field(
+        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
+        metadata={
+            "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
+        },
+    )
+    mlx_lm_gen_max_new_tokens: int = field(
+        default=128,
+        metadata={
+            "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
+        },
+    )
+    mlx_lm_gen_temperature: float = field(
+        default=0.0,
+        metadata={
+            "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
+        },
+    )
+    mlx_lm_gen_do_sample: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
+        },
+    )
+    mlx_lm_chat_size: int = field(
+        default=2,
+        metadata={
+            "help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
+        },
+    )
diff --git a/arguments_classes/vad_arguments.py b/arguments_classes/vad_arguments.py
index 2dfb378..450229c 100644
--- a/arguments_classes/vad_arguments.py
+++ b/arguments_classes/vad_arguments.py
@@ -34,7 +34,7 @@ class VADHandlerArguments:
         },
     )
     speech_pad_ms: int = field(
-        default=250,
+        default=500,
         metadata={
-            "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
+            "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 500 ms."
         },
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 4f5b14b..d950c23 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -11,6 +11,7 @@ from time import perf_counter
 from typing import Optional
 from sys import platform
 from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
+from arguments_classes.mlx_language_model_arguments import MLXLanguageModelHandlerArguments
 from arguments_classes.module_arguments import ModuleArguments
 from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
 from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
@@ -629,6 +630,7 @@ def main():
             VADHandlerArguments,
             WhisperSTTHandlerArguments,
             LanguageModelHandlerArguments,
+            MLXLanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
             MeloTTSHandlerArguments,
         )
@@ -644,6 +646,7 @@ def main():
             vad_handler_kwargs,
             whisper_stt_handler_kwargs,
             language_model_handler_kwargs,
+            mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
         ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
@@ -656,6 +659,7 @@ def main():
             vad_handler_kwargs,
             whisper_stt_handler_kwargs,
             language_model_handler_kwargs,
+            mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
         ) = parser.parse_args_into_dataclasses()
@@ -720,12 +724,14 @@ def main():
     overwrite_device_argument(
         module_kwargs.device,
         language_model_handler_kwargs,
+        mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
         whisper_stt_handler_kwargs,
     )
 
     prepare_args(whisper_stt_handler_kwargs, "stt")
     prepare_args(language_model_handler_kwargs, "lm")
+    prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
     prepare_args(parler_tts_handler_kwargs, "tts")
     prepare_args(melo_tts_handler_kwargs, "melo")
 
@@ -800,7 +806,7 @@ def main():
             stop_event,
             queue_in=text_prompt_queue,
             queue_out=lm_response_queue,
-            setup_kwargs=vars(language_model_handler_kwargs),
+            setup_kwargs=vars(mlx_language_model_handler_kwargs),
         )
     else:
         raise ValueError("The LLM should be either transformers or mlx-lm")
-- 
GitLab