diff --git a/listen_and_play.py b/listen_and_play.py index 398cb32eef70d1a63ad872e1859bcc37b636a4b9..179c9494537d1d72f0c4ecad3c11bb9ba222248b 100644 --- a/listen_and_play.py +++ b/listen_and_play.py @@ -133,6 +133,6 @@ def listen_and_play( if __name__ == "__main__": parser = HfArgumentParser((ListenAndPlayArguments,)) - listen_and_play_kwargs = parser.parse_args_into_dataclasses() + listen_and_play_kwargs, = parser.parse_args_into_dataclasses() listen_and_play(**vars(listen_and_play_kwargs)) diff --git a/s2s_pipeline.py b/s2s_pipeline.py index f89a55094c76d26c610e0feaade60ca2f6794ea3..e9cc5d9c1a90dd223fa98d57d694a47b4700186e 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -46,15 +46,6 @@ class ModuleArguments: "help": "Provide logging level. Example --log_level debug, default=warning." } ) - client: bool = field( - default=False, - metadata={"help": "Whether or not this module is run as client."} - ) - server: bool = field( - default=False, - metadata={"help": "Whether or not this module is run as server."} - ) - class ThreadManager: def __init__(self, handlers): @@ -688,83 +679,63 @@ def main(): module_kwargs.client = True module_kwargs.server = True - if module_kwargs.client and module_kwargs.server and (socket_receiver_kwargs.recv_host != "127.0.0.1" or socket_sender_kwargs.recv_host != "localhost"): - raise ValueError() + stop_event = Event() + should_listen = Event() + recv_audio_chunks_queue = Queue() + send_audio_chunks_queue = Queue() + spoken_prompt_queue = Queue() + text_prompt_queue = Queue() + llm_response_queue = Queue() + vad = VADHandler( + stop_event, + queue_in=recv_audio_chunks_queue, + queue_out=spoken_prompt_queue, + setup_args=(should_listen,), + setup_kwargs=vars(vad_handler_kwargs), + ) + stt = WhisperSTTHandler( + stop_event, + queue_in=spoken_prompt_queue, + queue_out=text_prompt_queue, + setup_kwargs=vars(whisper_stt_handler_kwargs), + ) + llm = LanguageModelHandler( + stop_event, + queue_in=text_prompt_queue, + queue_out=llm_response_queue, + setup_kwargs=vars(language_model_handler_kwargs), + ) + tts = ParlerTTSHandler( + stop_event, + queue_in=llm_response_queue, + queue_out=send_audio_chunks_queue, + setup_args=(should_listen,), + setup_kwargs=vars(parler_tts_handler_kwargs), + ) + + recv_handler = SocketReceiver( + stop_event, + recv_audio_chunks_queue, + should_listen, + host=socket_receiver_kwargs.recv_host, + port=socket_receiver_kwargs.recv_port, + chunk_size=socket_receiver_kwargs.chunk_size, + ) + + send_handler = SocketSender( + stop_event, + send_audio_chunks_queue, + host=socket_sender_kwargs.send_host, + port=socket_sender_kwargs.send_port, + ) + try: - if module_kwargs.server: - stop_event = Event() - should_listen = Event() - recv_audio_chunks_queue = Queue() - send_audio_chunks_queue = Queue() - spoken_prompt_queue = Queue() - text_prompt_queue = Queue() - llm_response_queue = Queue() - - vad = VADHandler( - stop_event, - queue_in=recv_audio_chunks_queue, - queue_out=spoken_prompt_queue, - setup_args=(should_listen,), - setup_kwargs=vars(vad_handler_kwargs), - ) - stt = WhisperSTTHandler( - stop_event, - queue_in=spoken_prompt_queue, - queue_out=text_prompt_queue, - setup_kwargs=vars(whisper_stt_handler_kwargs), - ) - llm = LanguageModelHandler( - stop_event, - queue_in=text_prompt_queue, - queue_out=llm_response_queue, - setup_kwargs=vars(language_model_handler_kwargs), - ) - tts = ParlerTTSHandler( - stop_event, - queue_in=llm_response_queue, - queue_out=send_audio_chunks_queue, - setup_args=(should_listen,), - setup_kwargs=vars(parler_tts_handler_kwargs), - ) - - recv_handler = SocketReceiver( - stop_event, - recv_audio_chunks_queue, - should_listen, - host=socket_receiver_kwargs.recv_host, - port=socket_receiver_kwargs.recv_port, - chunk_size=socket_receiver_kwargs.chunk_size, - ) - - send_handler = SocketSender( - stop_event, - send_audio_chunks_queue, - host=socket_sender_kwargs.send_host, - port=socket_sender_kwargs.send_port, - ) - - pipeline_manager = ThreadManager([vad, tts, llm, stt, recv_handler, send_handler]) - pipeline_manager.start() - - if module_kwargs.client: - from listen_and_play import listen_and_play - - kwargs = { - "host": "localhost", - "send_port": socket_sender_kwargs.send_port, - "recv_port": socket_receiver_kwargs.recv_port, - } - - listen_and_play_process = multiprocessing.Process( - target=listen_and_play, - kwargs=kwargs, - ) - listen_and_play_process.start() + pipeline_manager = ThreadManager([vad, tts, llm, stt, recv_handler, send_handler]) + pipeline_manager.start() except KeyboardInterrupt: pipeline_manager.stop() - listen_and_play_process.join() if __name__ == "__main__": main()