From dc926fce74c0f46399b879111d02f15af7020af6 Mon Sep 17 00:00:00 2001 From: Andres Marafioti <andimarafioti@gmail.com> Date: Thu, 22 Aug 2024 17:30:08 +0200 Subject: [PATCH] last changes --- README.md | 21 ++++++----- STT/lightning_whisper_mlx_handler.py | 6 ++-- TTS/melotts.py | 5 ++- handlers/melo_tts_handler.py | 26 ++++++++++++++ s2s_pipeline.py | 54 ++++++++++++++++++++++++++-- 5 files changed, 97 insertions(+), 15 deletions(-) create mode 100644 handlers/melo_tts_handler.py diff --git a/README.md b/README.md index 822c765..b604e97 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ pip install -r requirements.txt The pipeline can be run in two ways: - **Server/Client approach**: Models run on a server, and audio input/output are streamed from a client. -- **Local approach**: Uses the same client/server method but with the loopback address. +- **Local approach**: Runs locally. ### Server/Client Approach @@ -63,21 +63,24 @@ To run the pipeline on the server: python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0 ``` -Then run the client locally to handle sending microphone input and receiving generated audio: +Then run the pipeline locally: ```bash -python listen_and_play.py --host <IP address of your server> +python s2s_pipeline.py --mode local ``` -### Local Approach -Simply use the loopback address: +### Running on Mac +To run on mac, we recommend setting the flag `--local_mac_optimal_settings`: ```bash -python s2s_pipeline.py --recv_host localhost --send_host localhost -python listen_and_play.py --host localhost +python s2s_pipeline.py --local_mac_optimal_settings ``` -You can pass `--device mps` to run it locally on a Mac. +You can also pass `--device mps` to have all the models set to device mps. +The local mac optimal settings set the mode to be local as explained above and change the models to: +- LightningWhisperMLX +- MLX LM +- MeloTTS -### Recommended usage +### Recommended usage with Cuda Leverage Torch Compile for Whisper and Parler-TTS: diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py index 90a646f..5370902 100644 --- a/STT/lightning_whisper_mlx_handler.py +++ b/STT/lightning_whisper_mlx_handler.py @@ -20,15 +20,17 @@ class LightningWhisperSTTHandler(BaseHandler): def setup( self, - model_name="distil-whisper/distil-large-v3", + model_name="distil-large-v3", device="cuda", torch_dtype="float16", compile_mode=None, gen_kwargs={}, ): + if len(model_name.split('/')) > 1: + model_name = model_name.split('/')[-1] self.device = device self.model = LightningWhisperMLX( - model="distil-large-v3", batch_size=6, quant=None + model=model_name, batch_size=6, quant=None ) self.warmup() diff --git a/TTS/melotts.py b/TTS/melotts.py index 70b29f5..f1a712b 100644 --- a/TTS/melotts.py +++ b/TTS/melotts.py @@ -20,12 +20,15 @@ class MeloTTSHandler(BaseHandler): should_listen, device="mps", language="EN_NEWEST", + speaker_to_id="EN-Newest", + gen_kwargs={}, # Unused blocksize=512, ): + print(device) self.should_listen = should_listen self.device = device self.model = TTS(language=language, device=device) - self.speaker_id = self.model.hps.data.spk2id["EN-Newest"] + self.speaker_id = self.model.hps.data.spk2id[speaker_to_id] self.blocksize = blocksize self.warmup() diff --git a/handlers/melo_tts_handler.py b/handlers/melo_tts_handler.py new file mode 100644 index 0000000..88616c3 --- /dev/null +++ b/handlers/melo_tts_handler.py @@ -0,0 +1,26 @@ + +from dataclasses import dataclass, field +from typing import List + + +@dataclass +class MeloTTSHandlerArguments: + melo_language: str = field( + default="EN_NEWEST", + metadata={ + "help": "The language of the text to be synthesized. Default is 'EN_NEWEST'." + }, + ) + melo_device: str = field( + default="auto", + metadata={ + "help": "The device to be used for speech synthesis. Default is 'auto'." + }, + ) + melo_speaker_to_id: str = field( + default="EN-Newest", + metadata={ + "help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']." + }, + ) + diff --git a/s2s_pipeline.py b/s2s_pipeline.py index 14a5e2b..0299189 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -14,6 +14,7 @@ from sys import platform from LLM.mlx_lm import MLXLanguageModelHandler from baseHandler import BaseHandler from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler +from handlers.melo_tts_handler import MeloTTSHandlerArguments import numpy as np import torch import nltk @@ -56,11 +57,23 @@ class ModuleArguments: metadata={"help": "If specified, overrides the device for all handlers."}, ) mode: Optional[str] = field( - default="local", + default="socket", metadata={ "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'." }, ) + local_mac_optimal_settings: bool = field( + default=False, + metadata={ + "help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used." + }, + ) + stt: Optional[str] = field( + default="whisper", + metadata={ + "help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'." + }, + ) llm: Optional[str] = field( default="transformers", metadata={ @@ -916,6 +929,7 @@ def main(): WhisperSTTHandlerArguments, LanguageModelHandlerArguments, ParlerTTSHandlerArguments, + MeloTTSHandlerArguments, ) ) @@ -930,6 +944,7 @@ def main(): whisper_stt_handler_kwargs, language_model_handler_kwargs, parler_tts_handler_kwargs, + melo_tts_handler_kwargs, ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: # Parse arguments from command line if no JSON file is provided @@ -941,6 +956,7 @@ def main(): whisper_stt_handler_kwargs, language_model_handler_kwargs, parler_tts_handler_kwargs, + melo_tts_handler_kwargs, ) = parser.parse_args_into_dataclasses() # 1. Handle logger @@ -955,6 +971,26 @@ def main(): if module_kwargs.log_level == "debug": torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True) + + def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs): + if mac_optimal_settings: + for kwargs in handler_kwargs: + if hasattr(kwargs, "device"): + kwargs.device = "mps" + if hasattr(kwargs, "mode"): + kwargs.mode = "local" + if hasattr(kwargs, "stt"): + kwargs.stt = "whisper-mlx" + if hasattr(kwargs, "llm"): + kwargs.llm = "mlx-lm" + if hasattr(kwargs, "tts"): + kwargs.tts = "melo" + + optimal_mac_settings( + module_kwargs.local_mac_optimal_settings, + module_kwargs, + ) + if platform == "darwin": if module_kwargs.device == "cuda": raise ValueError( @@ -991,6 +1027,7 @@ def main(): prepare_args(whisper_stt_handler_kwargs, "stt") prepare_args(language_model_handler_kwargs, "lm") prepare_args(parler_tts_handler_kwargs, "tts") + prepare_args(melo_tts_handler_kwargs, "melo") # 3. Build the pipeline stop_event = Event() @@ -1033,12 +1070,22 @@ def main(): setup_args=(should_listen,), setup_kwargs=vars(vad_handler_kwargs), ) - stt = LightningWhisperSTTHandler( - stop_event, + if module_kwargs.stt == 'whisper': + stt = WhisperSTTHandler( + stop_event, queue_in=spoken_prompt_queue, queue_out=text_prompt_queue, setup_kwargs=vars(whisper_stt_handler_kwargs), ) + elif module_kwargs.stt == 'whisper-mlx': + stt = LightningWhisperSTTHandler( + stop_event, + queue_in=spoken_prompt_queue, + queue_out=text_prompt_queue, + setup_kwargs=vars(whisper_stt_handler_kwargs), + ) + else: + raise ValueError("The STT should be either whisper or whisper-mlx") if module_kwargs.llm == 'transformers': lm = LanguageModelHandler( stop_event, @@ -1078,6 +1125,7 @@ def main(): queue_in=lm_response_queue, queue_out=send_audio_chunks_queue, setup_args=(should_listen,), + setup_kwargs=vars(melo_tts_handler_kwargs), ) else: raise ValueError("The TTS should be either parler or melo") -- GitLab