From 7a1f9126c76f332d17ca291919883b8a9bb9ba1d Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan <eulebihan@gmail.com>
Date: Mon, 12 Aug 2024 16:15:38 +0000
Subject: [PATCH] fix client server

---
 listen_and_play.py |   2 +-
 s2s_pipeline.py    | 133 ++++++++++++++++++---------------------------
 2 files changed, 53 insertions(+), 82 deletions(-)

diff --git a/listen_and_play.py b/listen_and_play.py
index 398cb32..179c949 100644
--- a/listen_and_play.py
+++ b/listen_and_play.py
@@ -133,6 +133,6 @@ def listen_and_play(
 
 if __name__ == "__main__":
     parser = HfArgumentParser((ListenAndPlayArguments,))
-    listen_and_play_kwargs = parser.parse_args_into_dataclasses()
+    listen_and_play_kwargs, = parser.parse_args_into_dataclasses()
     listen_and_play(**vars(listen_and_play_kwargs))
 
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index f89a550..e9cc5d9 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -46,15 +46,6 @@ class ModuleArguments:
             "help": "Provide logging level. Example --log_level debug, default=warning."
         }
     )
-    client: bool = field(
-        default=False,
-        metadata={"help": "Whether or not this module is run as client."}
-    )
-    server: bool = field(
-        default=False,
-        metadata={"help": "Whether or not this module is run as server."}
-    )
-
 
 class ThreadManager:
     def __init__(self, handlers):
@@ -688,83 +679,63 @@ def main():
         module_kwargs.client = True
         module_kwargs.server = True
 
-    if module_kwargs.client and module_kwargs.server and (socket_receiver_kwargs.recv_host != "127.0.0.1" or socket_sender_kwargs.recv_host != "localhost"):
-        raise ValueError()
+    stop_event = Event()
+    should_listen = Event()
+    recv_audio_chunks_queue = Queue()
+    send_audio_chunks_queue = Queue()
+    spoken_prompt_queue = Queue() 
+    text_prompt_queue = Queue()
+    llm_response_queue = Queue()
     
+    vad = VADHandler(
+        stop_event,
+        queue_in=recv_audio_chunks_queue,
+        queue_out=spoken_prompt_queue,
+        setup_args=(should_listen,),
+        setup_kwargs=vars(vad_handler_kwargs),
+    )
+    stt = WhisperSTTHandler(
+        stop_event,
+        queue_in=spoken_prompt_queue,
+        queue_out=text_prompt_queue,
+        setup_kwargs=vars(whisper_stt_handler_kwargs),
+    )
+    llm = LanguageModelHandler(
+        stop_event,
+        queue_in=text_prompt_queue,
+        queue_out=llm_response_queue,
+        setup_kwargs=vars(language_model_handler_kwargs),
+    )
+    tts = ParlerTTSHandler(
+        stop_event,
+        queue_in=llm_response_queue,
+        queue_out=send_audio_chunks_queue,
+        setup_args=(should_listen,),
+        setup_kwargs=vars(parler_tts_handler_kwargs),
+    )  
+
+    recv_handler = SocketReceiver(
+        stop_event, 
+        recv_audio_chunks_queue, 
+        should_listen,
+        host=socket_receiver_kwargs.recv_host,
+        port=socket_receiver_kwargs.recv_port,
+        chunk_size=socket_receiver_kwargs.chunk_size,
+    )
+
+    send_handler = SocketSender(
+        stop_event, 
+        send_audio_chunks_queue,
+        host=socket_sender_kwargs.send_host,
+        port=socket_sender_kwargs.send_port,
+        )
+
     try:
-        if module_kwargs.server:
-            stop_event = Event()
-            should_listen = Event()
-            recv_audio_chunks_queue = Queue()
-            send_audio_chunks_queue = Queue()
-            spoken_prompt_queue = Queue() 
-            text_prompt_queue = Queue()
-            llm_response_queue = Queue()
-            
-            vad = VADHandler(
-                stop_event,
-                queue_in=recv_audio_chunks_queue,
-                queue_out=spoken_prompt_queue,
-                setup_args=(should_listen,),
-                setup_kwargs=vars(vad_handler_kwargs),
-            )
-            stt = WhisperSTTHandler(
-                stop_event,
-                queue_in=spoken_prompt_queue,
-                queue_out=text_prompt_queue,
-                setup_kwargs=vars(whisper_stt_handler_kwargs),
-            )
-            llm = LanguageModelHandler(
-                stop_event,
-                queue_in=text_prompt_queue,
-                queue_out=llm_response_queue,
-                setup_kwargs=vars(language_model_handler_kwargs),
-            )
-            tts = ParlerTTSHandler(
-                stop_event,
-                queue_in=llm_response_queue,
-                queue_out=send_audio_chunks_queue,
-                setup_args=(should_listen,),
-                setup_kwargs=vars(parler_tts_handler_kwargs),
-            )  
-
-            recv_handler = SocketReceiver(
-                stop_event, 
-                recv_audio_chunks_queue, 
-                should_listen,
-                host=socket_receiver_kwargs.recv_host,
-                port=socket_receiver_kwargs.recv_port,
-                chunk_size=socket_receiver_kwargs.chunk_size,
-            )
-
-            send_handler = SocketSender(
-                stop_event, 
-                send_audio_chunks_queue,
-                host=socket_sender_kwargs.send_host,
-                port=socket_sender_kwargs.send_port,
-            )
-
-            pipeline_manager = ThreadManager([vad, tts, llm, stt, recv_handler, send_handler])
-            pipeline_manager.start()
-
-        if module_kwargs.client:
-            from listen_and_play import listen_and_play
-            
-            kwargs = {
-                "host": "localhost",
-                "send_port": socket_sender_kwargs.send_port,
-                "recv_port": socket_receiver_kwargs.recv_port,
-            }
-
-            listen_and_play_process = multiprocessing.Process(
-                target=listen_and_play, 
-                kwargs=kwargs,
-            )
-            listen_and_play_process.start()
+        pipeline_manager = ThreadManager([vad, tts, llm, stt, recv_handler, send_handler])
+        pipeline_manager.start()
 
     except KeyboardInterrupt:
         pipeline_manager.stop()
-        listen_and_play_process.join()
     
 if __name__ == "__main__":
     main()
-- 
GitLab