From a4888572b535d92eb7803e85a31130b6d605fdff Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Thu, 22 Aug 2024 13:55:48 +0200
Subject: [PATCH] Let users choose their LLM

---
 s2s_pipeline.py | 34 ++++++++++++++++++++++++++++++++--
 1 file changed, 32 insertions(+), 2 deletions(-)

diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index b235c65..66aeaa5 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -10,7 +10,7 @@ from queue import Queue
 from threading import Event, Thread
 from time import perf_counter
 from typing import Optional
-
+from sys import platform
 from LLM.mlx_lm import MLXLanguageModelHandler
 from TTS.melotts import MeloTTSHandler
 from baseHandler import BaseHandler
@@ -62,6 +62,12 @@ class ModuleArguments:
             "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
         },
     )
+    llm: Optional[str] = field(
+        default="transformers",
+        metadata={
+            "help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'"
+        },
+    )
     tts: Optional[str] = field(
         default="parler",
         metadata={
@@ -950,6 +956,20 @@ def main():
     if module_kwargs.log_level == "debug":
         torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
 
+    if platform == "darwin":
+        if module_kwargs.device == "cuda":
+            raise ValueError(
+                "Cannot use CUDA on macOS. Please set the device to 'cpu' or 'mps'."
+            )
+        if module_kwargs.llm != "mlx-lm":
+            logger.warning(
+                "For macOS users, it is recommended to use mlx-lm."
+            )
+        if module_kwargs.tts != "melo":
+            logger.warning(
+                "If you experience issues generating the voice, consider setting the tts to melo."
+            )
+
     # 2. Prepare each part's arguments
     def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
         if common_device:
@@ -1020,12 +1040,22 @@ def main():
         queue_out=text_prompt_queue,
         setup_kwargs=vars(whisper_stt_handler_kwargs),
     )
-    lm = MLXLanguageModelHandler(
+    if module_kwargs.llm == 'transformers':
+        lm = LanguageModelHandler(
         stop_event,
         queue_in=text_prompt_queue,
         queue_out=lm_response_queue,
         setup_kwargs=vars(language_model_handler_kwargs),
     )
+    elif module_kwargs.llm == 'mlx-lm':
+        lm = MLXLanguageModelHandler(
+            stop_event,
+            queue_in=text_prompt_queue,
+            queue_out=lm_response_queue,
+            setup_kwargs=vars(language_model_handler_kwargs),
+        )
+    else:
+        raise ValueError("The LLM should be either transformers or mlx-lm")
     if module_kwargs.tts == 'parler':
         torch._inductor.config.fx_graph_cache = True
         # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
--
GitLab
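
Editor's note: below is a minimal, hypothetical usage sketch (not part of the patch) showing how the new llm switch would be exercised end to end. It assumes s2s_pipeline.py maps the ModuleArguments fields to CLI flags such as --llm and --device via its argument parser; that wiring sits outside this diff.

# Hypothetical smoke test for the new --llm flag. Assumes ModuleArguments
# fields surface as CLI flags (--llm, --device); the parser setup is not
# shown in this diff, so treat the flag names as assumptions.
import subprocess
import sys

def run_pipeline(llm: str, device: str = "cpu") -> int:
    # Launch the pipeline with the chosen LLM backend and return its exit code.
    cmd = [sys.executable, "s2s_pipeline.py", "--llm", llm, "--device", device]
    return subprocess.run(cmd).returncode

if __name__ == "__main__":
    run_pipeline("transformers")          # default backend
    run_pipeline("mlx-lm", device="mps")  # recommended on macOS per this patch
    run_pipeline("no-such-llm")           # exits nonzero: ValueError raised in main()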