From a4888572b535d92eb7803e85a31130b6d605fdff Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Thu, 22 Aug 2024 13:55:48 +0200
Subject: [PATCH] Let users choose their llm

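Add an llm module argument so the pipeline can select between the
transformers-based LanguageModelHandler and the MLX-based
MLXLanguageModelHandler, and warn macOS users about the recommended
backends.

As a rough usage sketch, assuming the ModuleArguments dataclass is
exposed through the existing argument parser so the new field becomes
an --llm flag (and the device field a --device flag):

    python s2s_pipeline.py --llm mlx-lm --device mps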
---
 s2s_pipeline.py | 44 +++++++++++++++++++++++++++++++++++++-------
 1 file changed, 37 insertions(+), 7 deletions(-)

diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index b235c65..66aeaa5 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -10,7 +10,7 @@ from queue import Queue
 from threading import Event, Thread
 from time import perf_counter
 from typing import Optional
-
+from sys import platform
 from LLM.mlx_lm import MLXLanguageModelHandler
 from TTS.melotts import MeloTTSHandler
 from baseHandler import BaseHandler
@@ -62,6 +62,12 @@ class ModuleArguments:
             "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
         },
     )
+    llm: Optional[str] =  field(
+        default="transformers",
+        metadata={
+            "help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'"
+        },
+    )
     tts: Optional[str] =  field(
         default="parler",
         metadata={
@@ -950,6 +956,20 @@ def main():
     if module_kwargs.log_level == "debug":
         torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
 
+    if platform == "darwin":
+        if module_kwargs.device == "cuda":
+            raise ValueError(
+                "Cannot use CUDA on macOS. Please set the device to 'cpu' or 'mps'."
+            )
+        if module_kwargs.llm != "mlx-lm":
+            logger.warning(
+                "For macOS users, it is recommended to use mlx-lm."
+            )
+        if module_kwargs.tts != "melo":
+            logger.warning(
+                "If you experiences issues generating the voice, considering setting the tts to melo."
+            )
+
     # 2. Prepare each part's arguments
     def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
         if common_device:
@@ -1020,12 +1040,22 @@ def main():
         queue_out=text_prompt_queue,
         setup_kwargs=vars(whisper_stt_handler_kwargs),
     )
-    lm = MLXLanguageModelHandler(
-        stop_event,
-        queue_in=text_prompt_queue,
-        queue_out=lm_response_queue,
-        setup_kwargs=vars(language_model_handler_kwargs),
-    )
+    if module_kwargs.llm == 'transformers':
+        lm = LanguageModelHandler(
+            stop_event,
+            queue_in=text_prompt_queue,
+            queue_out=lm_response_queue,
+            setup_kwargs=vars(language_model_handler_kwargs),
+        )
+    elif module_kwargs.llm == 'mlx-lm':
+        lm = MLXLanguageModelHandler(
+            stop_event,
+            queue_in=text_prompt_queue,
+            queue_out=lm_response_queue,
+            setup_kwargs=vars(language_model_handler_kwargs),
+        )
+    else:
+        raise ValueError("The LLM should be either 'transformers' or 'mlx-lm'.")
     if module_kwargs.tts == 'parler':
         torch._inductor.config.fx_graph_cache = True
         # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
-- 
GitLab