diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 9e5bb0493f6195e1d7fc31f5f4a87dc5b0da014f..4c0cac1d4658bf0b44505f44664f63b5e275e4ff 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -49,11 +49,10 @@ console = Console()
 logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs
 
 
-def prepare_args(args, prefix):
+def rename_args(args, prefix):
     """
     Rename arguments by removing the prefix and prepares the gen_kwargs.
     """
-
     gen_kwargs = {}
     for key in copy(args.__dict__):
         if key.startswith(prefix):
@@ -67,7 +66,7 @@ def prepare_args(args, prefix):
     args.__dict__["gen_kwargs"] = gen_kwargs
 
 
-def main():
+def parse_arguments():
     parser = HfArgumentParser(
         (
             ModuleArguments,
@@ -84,69 +83,43 @@
         )
     )
 
-    # 0. Parse CLI arguments
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # Parse configurations from a JSON file if specified
-        (
-            module_kwargs,
-            socket_receiver_kwargs,
-            socket_sender_kwargs,
-            vad_handler_kwargs,
-            whisper_stt_handler_kwargs,
-            paraformer_stt_handler_kwargs,
-            language_model_handler_kwargs,
-            mlx_language_model_handler_kwargs,
-            parler_tts_handler_kwargs,
-            melo_tts_handler_kwargs,
-            chat_tts_handler_kwargs,
-        ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+        return parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         # Parse arguments from command line if no JSON file is provided
-        (
-            module_kwargs,
-            socket_receiver_kwargs,
-            socket_sender_kwargs,
-            vad_handler_kwargs,
-            whisper_stt_handler_kwargs,
-            paraformer_stt_handler_kwargs,
-            language_model_handler_kwargs,
-            mlx_language_model_handler_kwargs,
-            parler_tts_handler_kwargs,
-            melo_tts_handler_kwargs,
-            chat_tts_handler_kwargs,
-        ) = parser.parse_args_into_dataclasses()
-
-    # 1. Handle logger
+        return parser.parse_args_into_dataclasses()
+
+
+def setup_logger(log_level):
     global logger
     logging.basicConfig(
-        level=module_kwargs.log_level.upper(),
+        level=log_level.upper(),
         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
     )
     logger = logging.getLogger(__name__)
 
     # torch compile logs
-    if module_kwargs.log_level == "debug":
+    if log_level == "debug":
         torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
 
 
-    def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
-        if mac_optimal_settings:
-            for kwargs in handler_kwargs:
-                if hasattr(kwargs, "device"):
-                    kwargs.device = "mps"
-                if hasattr(kwargs, "mode"):
-                    kwargs.mode = "local"
-                if hasattr(kwargs, "stt"):
-                    kwargs.stt = "whisper-mlx"
-                if hasattr(kwargs, "llm"):
-                    kwargs.llm = "mlx-lm"
-                if hasattr(kwargs, "tts"):
-                    kwargs.tts = "melo"
-
-    optimal_mac_settings(
-        module_kwargs.local_mac_optimal_settings,
-        module_kwargs,
-    )
+def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
+    if mac_optimal_settings:
+        for kwargs in handler_kwargs:
+            if hasattr(kwargs, "device"):
+                kwargs.device = "mps"
+            if hasattr(kwargs, "mode"):
+                kwargs.mode = "local"
+            if hasattr(kwargs, "stt"):
+                kwargs.stt = "whisper-mlx"
+            if hasattr(kwargs, "llm"):
+                kwargs.llm = "mlx-lm"
+            if hasattr(kwargs, "tts"):
+                kwargs.tts = "melo"
+
+
+def check_mac_settings(module_kwargs):
     if platform == "darwin":
         if module_kwargs.device == "cuda":
             raise ValueError(
@@ -161,46 +134,90 @@
                 "If you experiences issues generating the voice, considering setting the tts to melo."
             )
 
-    # 2. Prepare each part's arguments
-    def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
-        if common_device:
-            for kwargs in handler_kwargs:
-                if hasattr(kwargs, "lm_device"):
-                    kwargs.lm_device = common_device
-                if hasattr(kwargs, "tts_device"):
-                    kwargs.tts_device = common_device
-                if hasattr(kwargs, "stt_device"):
-                    kwargs.stt_device = common_device
-                if hasattr(kwargs, "paraformer_stt_device"):
-                    kwargs.paraformer_stt_device = common_device
-
-    # Call this function with the common device and all the handlers
-    overwrite_device_argument(
-        module_kwargs.device,
+
+def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
+    if common_device:
+        for kwargs in handler_kwargs:
+            if hasattr(kwargs, "lm_device"):
+                kwargs.lm_device = common_device
+            if hasattr(kwargs, "tts_device"):
+                kwargs.tts_device = common_device
+            if hasattr(kwargs, "stt_device"):
+                kwargs.stt_device = common_device
+            if hasattr(kwargs, "paraformer_stt_device"):
+                kwargs.paraformer_stt_device = common_device
+
+
+def prepare_module_args(module_kwargs, *handler_kwargs):
+    optimal_mac_settings(module_kwargs.local_mac_optimal_settings, module_kwargs)
+    if platform == "darwin":
+        check_mac_settings(module_kwargs)
+    overwrite_device_argument(module_kwargs.device, *handler_kwargs)
+
+
+def prepare_all_args(
+    module_kwargs,
+    whisper_stt_handler_kwargs,
+    paraformer_stt_handler_kwargs,
+    language_model_handler_kwargs,
+    mlx_language_model_handler_kwargs,
+    parler_tts_handler_kwargs,
+    melo_tts_handler_kwargs,
+    chat_tts_handler_kwargs,
+):
+    prepare_module_args(
+        module_kwargs,
+        whisper_stt_handler_kwargs,
+        paraformer_stt_handler_kwargs,
         language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
-        whisper_stt_handler_kwargs,
-        paraformer_stt_handler_kwargs,
+        melo_tts_handler_kwargs,
+        chat_tts_handler_kwargs,
     )
 
-    prepare_args(whisper_stt_handler_kwargs, "stt")
-    prepare_args(paraformer_stt_handler_kwargs, "paraformer_stt")
-    prepare_args(language_model_handler_kwargs, "lm")
-    prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
-    prepare_args(parler_tts_handler_kwargs, "tts")
-    prepare_args(melo_tts_handler_kwargs, "melo")
-    prepare_args(chat_tts_handler_kwargs, "chat_tts")
-
-    # 3. Build the pipeline
-    stop_event = Event()
-    # used to stop putting received audio chunks in queue until all setences have been processed by the TTS
-    should_listen = Event()
-    recv_audio_chunks_queue = Queue()
-    send_audio_chunks_queue = Queue()
-    spoken_prompt_queue = Queue()
-    text_prompt_queue = Queue()
-    lm_response_queue = Queue()
+    rename_args(whisper_stt_handler_kwargs, "stt")
+    rename_args(paraformer_stt_handler_kwargs, "paraformer_stt")
+    rename_args(language_model_handler_kwargs, "lm")
+    rename_args(mlx_language_model_handler_kwargs, "mlx_lm")
+    rename_args(parler_tts_handler_kwargs, "tts")
+    rename_args(melo_tts_handler_kwargs, "melo")
+    rename_args(chat_tts_handler_kwargs, "chat_tts")
+
+
+def initialize_queues_and_events():
+    return {
+        "stop_event": Event(),
+        "should_listen": Event(),
+        "recv_audio_chunks_queue": Queue(),
+        "send_audio_chunks_queue": Queue(),
+        "spoken_prompt_queue": Queue(),
+        "text_prompt_queue": Queue(),
+        "lm_response_queue": Queue(),
+    }
+
+
+def build_pipeline(
+    module_kwargs,
+    socket_receiver_kwargs,
+    socket_sender_kwargs,
+    vad_handler_kwargs,
+    whisper_stt_handler_kwargs,
+    paraformer_stt_handler_kwargs,
+    language_model_handler_kwargs,
+    mlx_language_model_handler_kwargs,
+    parler_tts_handler_kwargs,
+    melo_tts_handler_kwargs,
+    chat_tts_handler_kwargs,
+    queues_and_events,
+):
+    stop_event = queues_and_events["stop_event"]
+    should_listen = queues_and_events["should_listen"]
+    recv_audio_chunks_queue = queues_and_events["recv_audio_chunks_queue"]
+    send_audio_chunks_queue = queues_and_events["send_audio_chunks_queue"]
+    spoken_prompt_queue = queues_and_events["spoken_prompt_queue"]
+    text_prompt_queue = queues_and_events["text_prompt_queue"]
+    lm_response_queue = queues_and_events["lm_response_queue"]
 
     if module_kwargs.mode == "local":
         from connections.local_audio_streamer import LocalAudioStreamer
@@ -238,10 +255,18 @@ def main():
         setup_args=(should_listen,),
         setup_kwargs=vars(vad_handler_kwargs),
     )
+
+    stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
+    lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs)
+    tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs)
+
+    return ThreadManager([*comms_handlers, vad, stt, lm, tts])
+
+
+def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs):
     if module_kwargs.stt == "whisper":
         from STT.whisper_stt_handler import WhisperSTTHandler
-
-        stt = WhisperSTTHandler(
+        return WhisperSTTHandler(
             stop_event,
             queue_in=spoken_prompt_queue,
             queue_out=text_prompt_queue,
@@ -249,8 +274,7 @@
         )
     elif module_kwargs.stt == "whisper-mlx":
         from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
-
-        stt = LightningWhisperSTTHandler(
+        return LightningWhisperSTTHandler(
             stop_event,
             queue_in=spoken_prompt_queue,
             queue_out=text_prompt_queue,
@@ -258,21 +282,20 @@
         )
     elif module_kwargs.stt == "paraformer":
         from STT.paraformer_handler import ParaformerSTTHandler
-
-        stt = ParaformerSTTHandler(
+        return ParaformerSTTHandler(
             stop_event,
             queue_in=spoken_prompt_queue,
             queue_out=text_prompt_queue,
             setup_kwargs=vars(paraformer_stt_handler_kwargs),
         )
     else:
-        raise ValueError(
-            "The STT should be either whisper, whisper-mlx, or paraformer."
-        )
+        raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.")
+
+
+def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs):
     if module_kwargs.llm == "transformers":
         from LLM.language_model import LanguageModelHandler
-
-        lm = LanguageModelHandler(
+        return LanguageModelHandler(
             stop_event,
             queue_in=text_prompt_queue,
             queue_out=lm_response_queue,
@@ -280,8 +303,7 @@
         )
     elif module_kwargs.llm == "mlx-lm":
         from LLM.mlx_language_model import MLXLanguageModelHandler
-
-        lm = MLXLanguageModelHandler(
+        return MLXLanguageModelHandler(
             stop_event,
             queue_in=text_prompt_queue,
             queue_out=lm_response_queue,
@@ -289,10 +311,12 @@
         )
     else:
         raise ValueError("The LLM should be either transformers or mlx-lm")
+
+
+def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs):
     if module_kwargs.tts == "parler":
         from TTS.parler_handler import ParlerTTSHandler
-
-        tts = ParlerTTSHandler(
+        return ParlerTTSHandler(
             stop_event,
             queue_in=lm_response_queue,
             queue_out=send_audio_chunks_queue,
@@ -307,7 +331,7 @@ def main():
                 "Error importing MeloTTSHandler. You might need to run: python -m unidic download"
             )
             raise e
-        tts = MeloTTSHandler(
+        return MeloTTSHandler(
             stop_event,
             queue_in=lm_response_queue,
             queue_out=send_audio_chunks_queue,
@@ -320,7 +344,7 @@ def main():
         except RuntimeError as e:
             logger.error("Error importing ChatTTSHandler")
             raise e
-        tts = ChatTTSHandler(
+        return ChatTTSHandler(
             stop_event,
             queue_in=lm_response_queue,
             queue_out=send_audio_chunks_queue,
@@ -330,14 +354,57 @@
     else:
         raise ValueError("The TTS should be either parler, melo or chatTTS")
 
-    # 4. Run the pipeline
+
+def main():
+    (
+        module_kwargs,
+        socket_receiver_kwargs,
+        socket_sender_kwargs,
+        vad_handler_kwargs,
+        whisper_stt_handler_kwargs,
+        paraformer_stt_handler_kwargs,
+        language_model_handler_kwargs,
+        mlx_language_model_handler_kwargs,
+        parler_tts_handler_kwargs,
+        melo_tts_handler_kwargs,
+        chat_tts_handler_kwargs,
+    ) = parse_arguments()
+
+    setup_logger(module_kwargs.log_level)
+
+    prepare_all_args(
+        module_kwargs,
+        whisper_stt_handler_kwargs,
+        paraformer_stt_handler_kwargs,
+        language_model_handler_kwargs,
+        mlx_language_model_handler_kwargs,
+        parler_tts_handler_kwargs,
+        melo_tts_handler_kwargs,
+        chat_tts_handler_kwargs,
+    )
+
+    queues_and_events = initialize_queues_and_events()
+
+    pipeline_manager = build_pipeline(
+        module_kwargs,
+        socket_receiver_kwargs,
+        socket_sender_kwargs,
+        vad_handler_kwargs,
+        whisper_stt_handler_kwargs,
+        paraformer_stt_handler_kwargs,
+        language_model_handler_kwargs,
+        mlx_language_model_handler_kwargs,
+        parler_tts_handler_kwargs,
+        melo_tts_handler_kwargs,
+        chat_tts_handler_kwargs,
+        queues_and_events,
+    )
+
     try:
-        pipeline_manager = ThreadManager([*comms_handlers, vad, stt, lm, tts])
         pipeline_manager.start()
-
     except KeyboardInterrupt:
         pipeline_manager.stop()
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
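
Review note (not part of the patch): one payoff of splitting main() into module-level helpers is that the pieces can be smoke-tested in isolation. The sketch below is an illustration under stated assumptions, not code from this change: it assumes the refactored file is importable as s2s_pipeline with its dependencies installed, uses a made-up prefixed attribute stt_model_name, and takes rename_args to behave as its docstring describes; only behavior visible in the diff is asserted.

```python
# Hypothetical pytest-style sketch; not part of s2s_pipeline.py or this diff.
from types import SimpleNamespace

from s2s_pipeline import initialize_queues_and_events, rename_args


def test_rename_args_attaches_gen_kwargs():
    # `stt_model_name` is an invented attribute following the "stt_" prefix convention.
    args = SimpleNamespace(stt_model_name="dummy-model")
    rename_args(args, "stt")
    # The visible tail of rename_args always stores a gen_kwargs dict on args.
    assert isinstance(args.gen_kwargs, dict)


def test_initialize_queues_and_events_returns_expected_keys():
    queues_and_events = initialize_queues_and_events()
    assert set(queues_and_events) == {
        "stop_event",
        "should_listen",
        "recv_audio_chunks_queue",
        "send_audio_chunks_queue",
        "spoken_prompt_queue",
        "text_prompt_queue",
        "lm_response_queue",
    }
```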