Skip to content
Snippets Groups Projects
Commit 7c8d2487 authored by Andres Marafioti's avatar Andres Marafioti
Browse files

Improvements mlx pipeline

parent e417e55c
No related branches found
No related tags found
No related merge requests found
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class MLXLanguageModelHandlerArguments:
    """Command-line / config arguments for the MLX language-model handler.

    Every field is prefixed with ``mlx_lm_`` so the argument parser can
    route the values to the MLX handler (stripped via ``prepare_args(...,
    "mlx_lm")`` elsewhere in the pipeline).
    """

    # Hugging Face model id of the MLX-converted language model.
    mlx_lm_model_name: str = field(
        default="mlx-community/SmolLM-360M-Instruct",
        metadata={
            "help": "The pretrained language model to use. Default is 'mlx-community/SmolLM-360M-Instruct'."
        },
    )
    # MLX runs on Apple Silicon, hence the 'mps' default.
    mlx_lm_device: str = field(
        default="mps",
        metadata={
            "help": "The device type on which the model will run. Default is 'mps' (Apple Silicon GPU)."
        },
    )
    mlx_lm_torch_dtype: str = field(
        default="float16",
        metadata={
            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision). Default is 'float16'."
        },
    )
    mlx_lm_user_role: str = field(
        default="user",
        metadata={
            "help": "Role assigned to the user in the chat context. Default is 'user'."
        },
    )
    mlx_lm_init_chat_role: str = field(
        default="system",
        metadata={
            "help": "Initial role for setting up the chat context. Default is 'system'."
        },
    )
    # System prompt injected as the first chat message; kept short to bias
    # the model toward concise spoken-style answers.
    mlx_lm_init_chat_prompt: str = field(
        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
        metadata={
            "help": "The initial chat prompt to establish context for the language model. Default instructs the model to be a polite assistant giving concise (<20 word) responses."
        },
    )
    mlx_lm_gen_max_new_tokens: int = field(
        default=128,
        metadata={
            "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
        },
    )
    # 0.0 => greedy decoding (deterministic output).
    mlx_lm_gen_temperature: float = field(
        default=0.0,
        metadata={
            "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
        },
    )
    mlx_lm_gen_do_sample: bool = field(
        default=False,
        metadata={
            "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
        },
    )
    # Number of assistant-user exchange pairs retained in the rolling chat
    # history. A value of None disables truncation entirely.
    mlx_lm_chat_size: Optional[int] = field(
        default=2,
        metadata={
            "help": "Number of assistant-user interaction pairs to keep in the chat history. Default is 2. None for no limitation."
        },
    )
...@@ -34,7 +34,7 @@ class VADHandlerArguments: ...@@ -34,7 +34,7 @@ class VADHandlerArguments:
}, },
) )
speech_pad_ms: int = field( speech_pad_ms: int = field(
default=250, default=500,
metadata={ metadata={
"help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms." "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
}, },
......
...@@ -11,6 +11,7 @@ from time import perf_counter ...@@ -11,6 +11,7 @@ from time import perf_counter
from typing import Optional from typing import Optional
from sys import platform from sys import platform
from arguments_classes.language_model_arguments import LanguageModelHandlerArguments from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
from arguments_classes.mlx_language_model_arguments import MLXLanguageModelHandlerArguments
from arguments_classes.module_arguments import ModuleArguments from arguments_classes.module_arguments import ModuleArguments
from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
from arguments_classes.socket_receiver_arguments import SocketReceiverArguments from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
...@@ -629,6 +630,7 @@ def main(): ...@@ -629,6 +630,7 @@ def main():
VADHandlerArguments, VADHandlerArguments,
WhisperSTTHandlerArguments, WhisperSTTHandlerArguments,
LanguageModelHandlerArguments, LanguageModelHandlerArguments,
MLXLanguageModelHandlerArguments,
ParlerTTSHandlerArguments, ParlerTTSHandlerArguments,
MeloTTSHandlerArguments, MeloTTSHandlerArguments,
) )
...@@ -644,6 +646,7 @@ def main(): ...@@ -644,6 +646,7 @@ def main():
vad_handler_kwargs, vad_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
...@@ -656,6 +659,7 @@ def main(): ...@@ -656,6 +659,7 @@ def main():
vad_handler_kwargs, vad_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
) = parser.parse_args_into_dataclasses() ) = parser.parse_args_into_dataclasses()
...@@ -720,12 +724,14 @@ def main(): ...@@ -720,12 +724,14 @@ def main():
overwrite_device_argument( overwrite_device_argument(
module_kwargs.device, module_kwargs.device,
language_model_handler_kwargs, language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
) )
prepare_args(whisper_stt_handler_kwargs, "stt") prepare_args(whisper_stt_handler_kwargs, "stt")
prepare_args(language_model_handler_kwargs, "lm") prepare_args(language_model_handler_kwargs, "lm")
prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
prepare_args(parler_tts_handler_kwargs, "tts") prepare_args(parler_tts_handler_kwargs, "tts")
prepare_args(melo_tts_handler_kwargs, "melo") prepare_args(melo_tts_handler_kwargs, "melo")
...@@ -800,7 +806,7 @@ def main(): ...@@ -800,7 +806,7 @@ def main():
stop_event, stop_event,
queue_in=text_prompt_queue, queue_in=text_prompt_queue,
queue_out=lm_response_queue, queue_out=lm_response_queue,
setup_kwargs=vars(language_model_handler_kwargs), setup_kwargs=vars(mlx_language_model_handler_kwargs),
) )
else: else:
raise ValueError("The LLM should be either transformers or mlx-lm") raise ValueError("The LLM should be either transformers or mlx-lm")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment