From dc926fce74c0f46399b879111d02f15af7020af6 Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Thu, 22 Aug 2024 17:30:08 +0200
Subject: [PATCH] Add local mode, Mac-optimal settings, and MeloTTS handler arguments

---
 README.md                            | 21 ++++++-----
 STT/lightning_whisper_mlx_handler.py |  8 ++++--
 TTS/melotts.py                       |  5 ++-
 handlers/melo_tts_handler.py         | 27 ++++++++++++++
 s2s_pipeline.py                      | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------
 5 files changed, 112 insertions(+), 17 deletions(-)
 create mode 100644 handlers/melo_tts_handler.py

diff --git a/README.md b/README.md
index 822c765..b604e97 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@ pip install -r requirements.txt
 
 The pipeline can be run in two ways:
 - **Server/Client approach**: Models run on a server, and audio input/output are streamed from a client.
-- **Local approach**: Uses the same client/server method but with the loopback address.
+- **Local approach**: Runs the whole pipeline on a single machine.
 
 ### Server/Client Approach
 
@@ -63,21 +63,30 @@ To run the pipeline on the server:
 python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
 ```
 
 Then run the client locally to handle sending microphone input and receiving generated audio:
 ```bash
 python listen_and_play.py --host <IP address of your server>
 ```
 
 ### Local Approach
-Simply use the loopback address:
+To run the pipeline locally:
 ```bash
-python s2s_pipeline.py --recv_host localhost --send_host localhost
-python listen_and_play.py --host localhost
+python s2s_pipeline.py --mode local
 ```
 
-You can pass `--device mps` to run it locally on a Mac.
+### Running on Mac
+To run on a Mac, we recommend setting the flag `--local_mac_optimal_settings`:
+```bash
+python s2s_pipeline.py --local_mac_optimal_settings
+```
+
+You can also pass `--device mps` to run all the models on the MPS device.
+Passing `--local_mac_optimal_settings` sets the mode to local, as described above, and switches the models to:
+- LightningWhisperMLX
+- MLX LM
+- MeloTTS
 
-### Recommended usage
+### Recommended usage with CUDA
 
 Leverage Torch Compile for Whisper and Parler-TTS:
 
diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py
index 90a646f..5370902 100644
--- a/STT/lightning_whisper_mlx_handler.py
+++ b/STT/lightning_whisper_mlx_handler.py
@@ -20,15 +20,19 @@ class LightningWhisperSTTHandler(BaseHandler):
 
     def setup(
         self,
-        model_name="distil-whisper/distil-large-v3",
+        model_name="distil-large-v3",
         device="cuda",
         torch_dtype="float16",
         compile_mode=None,
         gen_kwargs={},
     ):
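+        # Lightning Whisper MLX expects a bare model name, so strip any
+        # Hugging Face namespace prefix (e.g. "distil-whisper/distil-large-v3").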
+        if "/" in model_name:
+            model_name = model_name.split("/")[-1]
         self.device = device
         self.model = LightningWhisperMLX(
-            model="distil-large-v3", batch_size=6, quant=None
+            model=model_name, batch_size=6, quant=None
         )
         self.warmup()
 
diff --git a/TTS/melotts.py b/TTS/melotts.py
index 70b29f5..f1a712b 100644
--- a/TTS/melotts.py
+++ b/TTS/melotts.py
@@ -20,12 +20,15 @@ class MeloTTSHandler(BaseHandler):
         should_listen,
         device="mps",
         language="EN_NEWEST",
+        speaker_to_id="EN-Newest",
+        gen_kwargs={},  # Unused; kept so the TTS handlers share the same setup signature
         blocksize=512,
     ):
         self.should_listen = should_listen
         self.device = device
         self.model = TTS(language=language, device=device)
-        self.speaker_id = self.model.hps.data.spk2id["EN-Newest"]
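+        # spk2id maps speaker names (e.g. "EN-Newest") to the model's speaker IDs.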
+        self.speaker_id = self.model.hps.data.spk2id[speaker_to_id]
         self.blocksize = blocksize
         self.warmup()
 
diff --git a/handlers/melo_tts_handler.py b/handlers/melo_tts_handler.py
new file mode 100644
index 0000000..88616c3
--- /dev/null
+++ b/handlers/melo_tts_handler.py
@@ -0,0 +1,27 @@
+from dataclasses import dataclass, field
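+# These arguments surface on the CLI with their "melo_" prefix (e.g.
+# --melo_language); prepare_args() in s2s_pipeline.py strips the prefix
+# before the values are passed to MeloTTSHandler.setup().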
+
+
+@dataclass
+class MeloTTSHandlerArguments:
+    melo_language: str = field(
+        default="EN_NEWEST",
+        metadata={
+            "help": "The language of the text to be synthesized. Default is 'EN_NEWEST'."
+        },
+    )
+    melo_device: str = field(
+        default="auto",
+        metadata={
+            "help": "The device to be used for speech synthesis. Default is 'auto'."
+        },
+    )
+    melo_speaker_to_id: str = field(
+        default="EN-Newest",
+        metadata={
+            "help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']."
+        },
+    )
+
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 14a5e2b..0299189 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -14,6 +14,7 @@ from sys import platform
 from LLM.mlx_lm import MLXLanguageModelHandler
 from baseHandler import BaseHandler
 from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
+from handlers.melo_tts_handler import MeloTTSHandlerArguments
 import numpy as np
 import torch
 import nltk
@@ -56,11 +57,23 @@ class ModuleArguments:
         metadata={"help": "If specified, overrides the device for all handlers."},
     )
     mode: Optional[str] = field(
-        default="local",
+        default="socket",
         metadata={
             "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
         },
     )
+    local_mac_optimal_settings: bool = field(
+        default=False,
+        metadata={
+            "help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used."
+        },
+    )
+    stt: Optional[str] = field(
+        default="whisper",
+        metadata={
+            "help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'."
+        },
+    )
     llm: Optional[str] =  field(
         default="transformers",
         metadata={
@@ -916,6 +929,7 @@ def main():
             WhisperSTTHandlerArguments,
             LanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
+            MeloTTSHandlerArguments,
         )
     )
 
@@ -930,6 +944,7 @@ def main():
             whisper_stt_handler_kwargs,
             language_model_handler_kwargs,
             parler_tts_handler_kwargs,
+            melo_tts_handler_kwargs,
         ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         # Parse arguments from command line if no JSON file is provided
@@ -941,6 +956,7 @@ def main():
             whisper_stt_handler_kwargs,
             language_model_handler_kwargs,
             parler_tts_handler_kwargs,
+            melo_tts_handler_kwargs,
         ) = parser.parse_args_into_dataclasses()
 
     # 1. Handle logger
@@ -955,6 +971,28 @@ def main():
     if module_kwargs.log_level == "debug":
         torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
 
+
+    def optimal_mac_settings(mac_optimal_settings: bool, *handler_kwargs):
+        if mac_optimal_settings:
+            for kwargs in handler_kwargs:
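+                # Each handler's kwargs dataclass defines only a subset of these
+                # fields, so probe with hasattr() before overriding anything.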
+                if hasattr(kwargs, "device"):
+                    kwargs.device = "mps"
+                if hasattr(kwargs, "mode"):
+                    kwargs.mode = "local"
+                if hasattr(kwargs, "stt"):
+                    kwargs.stt = "whisper-mlx"
+                if hasattr(kwargs, "llm"):
+                    kwargs.llm = "mlx-lm"
+                if hasattr(kwargs, "tts"):
+                    kwargs.tts = "melo"
+
+    optimal_mac_settings(
+        module_kwargs.local_mac_optimal_settings,
+        module_kwargs,
+    )
+
     if platform == "darwin":
         if module_kwargs.device == "cuda":
             raise ValueError(
@@ -991,6 +1029,7 @@ def main():
     prepare_args(whisper_stt_handler_kwargs, "stt")
     prepare_args(language_model_handler_kwargs, "lm")
     prepare_args(parler_tts_handler_kwargs, "tts")
+    prepare_args(melo_tts_handler_kwargs, "melo")
 
     # 3. Build the pipeline
     stop_event = Event()
@@ -1033,12 +1072,24 @@ def main():
         setup_args=(should_listen,),
         setup_kwargs=vars(vad_handler_kwargs),
     )
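+    # Choose the STT backend: 'whisper' (Transformers) or 'whisper-mlx'
+    # (Lightning Whisper MLX, Apple Silicon only).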
-    stt = LightningWhisperSTTHandler(
-        stop_event,
-        queue_in=spoken_prompt_queue,
-        queue_out=text_prompt_queue,
-        setup_kwargs=vars(whisper_stt_handler_kwargs),
-    )
+    if module_kwargs.stt == 'whisper':
+        stt = WhisperSTTHandler(
+            stop_event,
+            queue_in=spoken_prompt_queue,
+            queue_out=text_prompt_queue,
+            setup_kwargs=vars(whisper_stt_handler_kwargs),
+        )
+    elif module_kwargs.stt == 'whisper-mlx':
+        stt = LightningWhisperSTTHandler(
+            stop_event,
+            queue_in=spoken_prompt_queue,
+            queue_out=text_prompt_queue,
+            setup_kwargs=vars(whisper_stt_handler_kwargs),
+        )
+    else:
+        raise ValueError("The STT should be either whisper or whisper-mlx")
     if module_kwargs.llm == 'transformers':
         lm = LanguageModelHandler(
         stop_event,
@@ -1078,6 +1129,7 @@ def main():
             queue_in=lm_response_queue,
             queue_out=send_audio_chunks_queue,
             setup_args=(should_listen,),
+            setup_kwargs=vars(melo_tts_handler_kwargs),
         )
     else:
         raise ValueError("The TTS should be either parler or melo")
-- 
GitLab