Skip to content
Snippets Groups Projects
Unverified Commit 30cfb4b7 authored by Andrés Marafioti's avatar Andrés Marafioti Committed by GitHub
Browse files

Merge pull request #41 from huggingface/mlx_lm_improvements

Improvements mlx pipeline
parents e417e55c 7c8d2487
No related branches found
No related tags found
No related merge requests found
from dataclasses import dataclass, field
@dataclass
class MLXLanguageModelHandlerArguments:
    """Configuration for the mlx-lm language-model handler.

    Each field maps to a CLI flag (via HfArgumentParser-style parsing);
    the ``help`` metadata is shown to users, so it must agree with the
    actual defaults.
    """

    mlx_lm_model_name: str = field(
        default="mlx-community/SmolLM-360M-Instruct",
        metadata={
            # Help text previously claimed 'microsoft/Phi-3-mini-4k-instruct';
            # fixed to match the real default.
            "help": "The pretrained language model to use. Default is 'mlx-community/SmolLM-360M-Instruct'."
        },
    )
    mlx_lm_device: str = field(
        default="mps",
        metadata={
            # Help text previously claimed 'cuda'; mlx runs on Apple Silicon,
            # so the real default is 'mps'.
            "help": "The device type on which the model will run. Default is 'mps' (Apple Silicon GPU)."
        },
    )
    mlx_lm_torch_dtype: str = field(
        default="float16",
        metadata={
            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision). Default is 'float16'."
        },
    )
    mlx_lm_user_role: str = field(
        default="user",
        metadata={
            "help": "Role assigned to the user in the chat context. Default is 'user'."
        },
    )
    mlx_lm_init_chat_role: str = field(
        default="system",
        metadata={
            "help": "Initial role for setting up the chat context. Default is 'system'."
        },
    )
    mlx_lm_init_chat_prompt: str = field(
        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
        metadata={
            # Dropped the stale restatement of the default from the help text;
            # the default string above is authoritative.
            "help": "The initial chat prompt to establish context for the language model."
        },
    )
    mlx_lm_gen_max_new_tokens: int = field(
        default=128,
        metadata={
            "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
        },
    )
    mlx_lm_gen_temperature: float = field(
        default=0.0,
        metadata={
            "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
        },
    )
    mlx_lm_gen_do_sample: bool = field(
        default=False,
        metadata={
            "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
        },
    )
    mlx_lm_chat_size: int = field(
        default=2,
        metadata={
            # Fixed typo "assitant" -> "assistant".
            "help": "Number of assistant-user interactions to keep for the chat. None for no limitations. Default is 2."
        },
    )
...@@ -34,7 +34,7 @@ class VADHandlerArguments: ...@@ -34,7 +34,7 @@ class VADHandlerArguments:
}, },
) )
speech_pad_ms: int = field( speech_pad_ms: int = field(
default=250, default=500,
metadata={ metadata={
"help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms." "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
}, },
......
...@@ -11,6 +11,7 @@ from time import perf_counter ...@@ -11,6 +11,7 @@ from time import perf_counter
from typing import Optional from typing import Optional
from sys import platform from sys import platform
from arguments_classes.language_model_arguments import LanguageModelHandlerArguments from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
from arguments_classes.mlx_language_model_arguments import MLXLanguageModelHandlerArguments
from arguments_classes.module_arguments import ModuleArguments from arguments_classes.module_arguments import ModuleArguments
from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
from arguments_classes.socket_receiver_arguments import SocketReceiverArguments from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
...@@ -629,6 +630,7 @@ def main(): ...@@ -629,6 +630,7 @@ def main():
VADHandlerArguments, VADHandlerArguments,
WhisperSTTHandlerArguments, WhisperSTTHandlerArguments,
LanguageModelHandlerArguments, LanguageModelHandlerArguments,
MLXLanguageModelHandlerArguments,
ParlerTTSHandlerArguments, ParlerTTSHandlerArguments,
MeloTTSHandlerArguments, MeloTTSHandlerArguments,
) )
...@@ -644,6 +646,7 @@ def main(): ...@@ -644,6 +646,7 @@ def main():
vad_handler_kwargs, vad_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
...@@ -656,6 +659,7 @@ def main(): ...@@ -656,6 +659,7 @@ def main():
vad_handler_kwargs, vad_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
) = parser.parse_args_into_dataclasses() ) = parser.parse_args_into_dataclasses()
...@@ -720,12 +724,14 @@ def main(): ...@@ -720,12 +724,14 @@ def main():
overwrite_device_argument( overwrite_device_argument(
module_kwargs.device, module_kwargs.device,
language_model_handler_kwargs, language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
) )
prepare_args(whisper_stt_handler_kwargs, "stt") prepare_args(whisper_stt_handler_kwargs, "stt")
prepare_args(language_model_handler_kwargs, "lm") prepare_args(language_model_handler_kwargs, "lm")
prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
prepare_args(parler_tts_handler_kwargs, "tts") prepare_args(parler_tts_handler_kwargs, "tts")
prepare_args(melo_tts_handler_kwargs, "melo") prepare_args(melo_tts_handler_kwargs, "melo")
...@@ -800,7 +806,7 @@ def main(): ...@@ -800,7 +806,7 @@ def main():
stop_event, stop_event,
queue_in=text_prompt_queue, queue_in=text_prompt_queue,
queue_out=lm_response_queue, queue_out=lm_response_queue,
setup_kwargs=vars(language_model_handler_kwargs), setup_kwargs=vars(mlx_language_model_handler_kwargs),
) )
else: else:
raise ValueError("The LLM should be either transformers or mlx-lm") raise ValueError("The LLM should be either transformers or mlx-lm")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment