diff --git a/README.md b/README.md index 822c765b84a455c4f2f3bf2175fe664f91c6ab3e..b604e97e0423a58f197ccc6a0339ff656543c43f 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ pip install -r requirements.txt The pipeline can be run in two ways: - **Server/Client approach**: Models run on a server, and audio input/output are streamed from a client. -- **Local approach**: Uses the same client/server method but with the loopback address. +- **Local approach**: Runs locally. ### Server/Client Approach @@ -63,21 +63,24 @@ To run the pipeline on the server: python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0 ``` -Then run the client locally to handle sending microphone input and receiving generated audio: +Then run the pipeline locally: ```bash -python listen_and_play.py --host <IP address of your server> +python s2s_pipeline.py --mode local ``` -### Local Approach -Simply use the loopback address: +### Running on Mac +To run on Mac, we recommend setting the flag `--local_mac_optimal_settings`: ```bash -python s2s_pipeline.py --recv_host localhost --send_host localhost -python listen_and_play.py --host localhost +python s2s_pipeline.py --local_mac_optimal_settings ``` -You can pass `--device mps` to run it locally on a Mac. +You can also pass `--device mps` to run all of the models on the mps device. 
+The local Mac optimal settings set the mode to local as explained above and change the models to: +- LightningWhisperMLX +- MLX LM +- MeloTTS -### Recommended usage +### Recommended usage with CUDA Leverage Torch Compile for Whisper and Parler-TTS: diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py index 90a646f52c4844462aa87291e9ae717f12db20b7..53709025c7438c559e06a098b70c697922eef6a1 100644 --- a/STT/lightning_whisper_mlx_handler.py +++ b/STT/lightning_whisper_mlx_handler.py @@ -20,15 +20,17 @@ class LightningWhisperSTTHandler(BaseHandler): def setup( self, - model_name="distil-whisper/distil-large-v3", + model_name="distil-large-v3", device="cuda", torch_dtype="float16", compile_mode=None, gen_kwargs={}, ): + if len(model_name.split('/')) > 1: + model_name = model_name.split('/')[-1] self.device = device self.model = LightningWhisperMLX( - model="distil-large-v3", batch_size=6, quant=None + model=model_name, batch_size=6, quant=None ) self.warmup() diff --git a/TTS/melotts.py b/TTS/melotts.py index 70b29f5994202d4309600b8cbe55673917086e14..f1a712bb98fc050c03e114dd24e2d62ae63284b8 100644 --- a/TTS/melotts.py +++ b/TTS/melotts.py @@ -20,12 +20,15 @@ class MeloTTSHandler(BaseHandler): def setup( self, should_listen, device="mps", language="EN_NEWEST", + speaker_to_id="EN-Newest", + gen_kwargs={}, # Unused blocksize=512, ): + print(device) self.should_listen = should_listen self.device = device self.model = TTS(language=language, device=device) - self.speaker_id = self.model.hps.data.spk2id["EN-Newest"] + self.speaker_id = self.model.hps.data.spk2id[speaker_to_id] self.blocksize = blocksize self.warmup() diff --git a/handlers/melo_tts_handler.py b/handlers/melo_tts_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..88616c34e1077b4d922fa820ad530dae6a42ac17 --- /dev/null +++ b/handlers/melo_tts_handler.py @@ -0,0 +1,26 @@ + +from dataclasses import dataclass, field +from typing import List + + +@dataclass +class 
MeloTTSHandlerArguments: + melo_language: str = field( + default="EN_NEWEST", + metadata={ + "help": "The language of the text to be synthesized. Default is 'EN_NEWEST'." + }, + ) + melo_device: str = field( + default="auto", + metadata={ + "help": "The device to be used for speech synthesis. Default is 'auto'." + }, + ) + melo_speaker_to_id: str = field( + default="EN-Newest", + metadata={ + "help": "The speaker ID to be used for speech synthesis. Default is 'EN-Newest'." + }, + ) + diff --git a/s2s_pipeline.py b/s2s_pipeline.py index 14a5e2b79919317cd05c9c78372ae3cd1dbe1428..0299189cd8d8f317366aa2aa9a5f229c15e686a3 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -14,6 +14,7 @@ from sys import platform from LLM.mlx_lm import MLXLanguageModelHandler from baseHandler import BaseHandler from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler +from handlers.melo_tts_handler import MeloTTSHandlerArguments import numpy as np import torch import nltk @@ -56,11 +57,23 @@ class ModuleArguments: metadata={"help": "If specified, overrides the device for all handlers."}, ) mode: Optional[str] = field( - default="local", + default="socket", metadata={ + "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'." }, ) + local_mac_optimal_settings: bool = field( + default=False, + metadata={ + "help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used." + }, + ) + stt: Optional[str] = field( + default="whisper", + metadata={ + "help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'." 
+ }, + ) llm: Optional[str] = field( default="transformers", metadata={ @@ -916,6 +929,7 @@ def main(): WhisperSTTHandlerArguments, LanguageModelHandlerArguments, ParlerTTSHandlerArguments, + MeloTTSHandlerArguments, ) ) @@ -930,6 +944,7 @@ def main(): whisper_stt_handler_kwargs, language_model_handler_kwargs, parler_tts_handler_kwargs, + melo_tts_handler_kwargs, ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: # Parse arguments from command line if no JSON file is provided @@ -941,6 +956,7 @@ def main(): whisper_stt_handler_kwargs, language_model_handler_kwargs, parler_tts_handler_kwargs, + melo_tts_handler_kwargs, ) = parser.parse_args_into_dataclasses() # 1. Handle logger @@ -955,6 +971,26 @@ def main(): if module_kwargs.log_level == "debug": torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True) + + def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs): + if mac_optimal_settings: + for kwargs in handler_kwargs: + if hasattr(kwargs, "device"): + kwargs.device = "mps" + if hasattr(kwargs, "mode"): + kwargs.mode = "local" + if hasattr(kwargs, "stt"): + kwargs.stt = "whisper-mlx" + if hasattr(kwargs, "llm"): + kwargs.llm = "mlx-lm" + if hasattr(kwargs, "tts"): + kwargs.tts = "melo" + + optimal_mac_settings( + module_kwargs.local_mac_optimal_settings, + module_kwargs, + ) + if platform == "darwin": if module_kwargs.device == "cuda": raise ValueError( @@ -991,6 +1027,7 @@ def main(): prepare_args(whisper_stt_handler_kwargs, "stt") prepare_args(language_model_handler_kwargs, "lm") prepare_args(parler_tts_handler_kwargs, "tts") + prepare_args(melo_tts_handler_kwargs, "melo") # 3. 
Build the pipeline stop_event = Event() @@ -1033,12 +1070,22 @@ def main(): setup_args=(should_listen,), setup_kwargs=vars(vad_handler_kwargs), ) - stt = LightningWhisperSTTHandler( - stop_event, + if module_kwargs.stt == 'whisper': + stt = WhisperSTTHandler( + stop_event, queue_in=spoken_prompt_queue, queue_out=text_prompt_queue, setup_kwargs=vars(whisper_stt_handler_kwargs), ) + elif module_kwargs.stt == 'whisper-mlx': + stt = LightningWhisperSTTHandler( + stop_event, + queue_in=spoken_prompt_queue, + queue_out=text_prompt_queue, + setup_kwargs=vars(whisper_stt_handler_kwargs), + ) + else: + raise ValueError("The STT should be either whisper or whisper-mlx") if module_kwargs.llm == 'transformers': lm = LanguageModelHandler( stop_event, @@ -1078,6 +1125,7 @@ def main(): queue_in=lm_response_queue, queue_out=send_audio_chunks_queue, setup_args=(should_listen,), + setup_kwargs=vars(melo_tts_handler_kwargs), ) else: raise ValueError("The TTS should be either parler or melo")