From 7c8d24871924968a0026461fc7f205c32842cac0 Mon Sep 17 00:00:00 2001 From: Andres Marafioti <andimarafioti@gmail.com> Date: Fri, 23 Aug 2024 14:13:18 +0200 Subject: [PATCH] Improvements mlx pipeline --- .../mlx_language_model_arguments.py | 65 +++++++++++++++++++ arguments_classes/vad_arguments.py | 2 +- s2s_pipeline.py | 8 ++- 3 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 arguments_classes/mlx_language_model_arguments.py diff --git a/arguments_classes/mlx_language_model_arguments.py b/arguments_classes/mlx_language_model_arguments.py new file mode 100644 index 0000000..0765ec9 --- /dev/null +++ b/arguments_classes/mlx_language_model_arguments.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass, field + + +@dataclass +class MLXLanguageModelHandlerArguments: + mlx_lm_model_name: str = field( + default="mlx-community/SmolLM-360M-Instruct", + metadata={ + "help": "The pretrained language model to use. Default is 'mlx-community/SmolLM-360M-Instruct'." + }, + ) + mlx_lm_device: str = field( + default="mps", + metadata={ + "help": "The device type on which the model will run. Default is 'mps' (Apple Silicon GPU)." + }, + ) + mlx_lm_torch_dtype: str = field( + default="float16", + metadata={ + "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." + }, + ) + mlx_lm_user_role: str = field( + default="user", + metadata={ + "help": "Role assigned to the user in the chat context. Default is 'user'." + }, + ) + mlx_lm_init_chat_role: str = field( + default="system", + metadata={ + "help": "Initial role for setting up the chat context. Default is 'system'." + }, + ) + mlx_lm_init_chat_prompt: str = field( + default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.", + metadata={ + "help": "The initial chat prompt to establish context for the language model. 
Default is a friendly, concise assistant prompt." }, ) + mlx_lm_gen_max_new_tokens: int = field( + default=128, + metadata={ + "help": "Maximum number of new tokens to generate in a single completion. Default is 128." + }, + ) + mlx_lm_gen_temperature: float = field( + default=0.0, + metadata={ + "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0." + }, + ) + mlx_lm_gen_do_sample: bool = field( + default=False, + metadata={ + "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False." + }, + ) + mlx_lm_chat_size: int = field( + default=2, + metadata={ + "help": "Number of assistant-user interactions to keep for the chat. None for no limitations." + }, + ) diff --git a/arguments_classes/vad_arguments.py b/arguments_classes/vad_arguments.py index 2dfb378..450229c 100644 --- a/arguments_classes/vad_arguments.py +++ b/arguments_classes/vad_arguments.py @@ -34,7 +34,7 @@ class VADHandlerArguments: }, ) speech_pad_ms: int = field( - default=250, + default=500, metadata={ "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms." 
}, diff --git a/s2s_pipeline.py b/s2s_pipeline.py index 4f5b14b..d950c23 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -11,6 +11,7 @@ from time import perf_counter from typing import Optional from sys import platform from arguments_classes.language_model_arguments import LanguageModelHandlerArguments +from arguments_classes.mlx_language_model_arguments import MLXLanguageModelHandlerArguments from arguments_classes.module_arguments import ModuleArguments from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments from arguments_classes.socket_receiver_arguments import SocketReceiverArguments @@ -629,6 +630,7 @@ def main(): VADHandlerArguments, WhisperSTTHandlerArguments, LanguageModelHandlerArguments, + MLXLanguageModelHandlerArguments, ParlerTTSHandlerArguments, MeloTTSHandlerArguments, ) @@ -644,6 +646,7 @@ def main(): vad_handler_kwargs, whisper_stt_handler_kwargs, language_model_handler_kwargs, + mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) @@ -656,6 +659,7 @@ def main(): vad_handler_kwargs, whisper_stt_handler_kwargs, language_model_handler_kwargs, + mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, ) = parser.parse_args_into_dataclasses() @@ -720,12 +724,14 @@ def main(): overwrite_device_argument( module_kwargs.device, language_model_handler_kwargs, + mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, whisper_stt_handler_kwargs, ) prepare_args(whisper_stt_handler_kwargs, "stt") prepare_args(language_model_handler_kwargs, "lm") + prepare_args(mlx_language_model_handler_kwargs, "mlx_lm") prepare_args(parler_tts_handler_kwargs, "tts") prepare_args(melo_tts_handler_kwargs, "melo") @@ -800,7 +806,7 @@ def main(): stop_event, queue_in=text_prompt_queue, queue_out=lm_response_queue, - setup_kwargs=vars(language_model_handler_kwargs), + 
setup_kwargs=vars(mlx_language_model_handler_kwargs), ) else: raise ValueError("The LLM should be either transformers or mlx-lm") -- GitLab