From 011d25509d32fa2ec79904e13184045b5c592998 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan <eulebihan@gmail.com>
Date: Mon, 12 Aug 2024 15:07:33 +0000
Subject: [PATCH] handle local run

---
 listen_and_play.py | 18 +++++++++---------
 s2s_pipeline.py    | 14 ++++++++------
 2 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/listen_and_play.py b/listen_and_play.py
index cb18fc1..398cb32 100644
--- a/listen_and_play.py
+++ b/listen_and_play.py
@@ -3,6 +3,7 @@ import numpy as np
 import threading
 from queue import Queue
 from dataclasses import dataclass, field
+import sounddevice as sd
 from transformers import HfArgumentParser
 
 
@@ -26,19 +27,19 @@ class ListenAndPlayArguments:
             "help": "The size of data chunks (in bytes). Default is 1024."
         }
     )
-    listen_play_host: str = field(
+    host: str = field(
         default="localhost",
         metadata={
             "help": "The hostname or IP address for listening and playing. Default is 'localhost'."
         }
     )
-    listen_play_send_port: int = field(
+    send_port: int = field(
         default=12345,
         metadata={
             "help": "The network port for sending data. Default is 12345."
         }
     )
-    listen_play_recv_port: int = field(
+    recv_port: int = field(
         default=12346,
         metadata={
             "help": "The network port for receiving data. Default is 12346."
@@ -50,17 +51,16 @@ def listen_and_play(
     send_rate=16000,
     recv_rate=44100,
     list_play_chunk_size=1024,
-    listen_play_host="localhost",
-    listen_play_send_port=12345,
-    listen_play_recv_port=12346,
+    host="localhost",
+    send_port=12345,
+    recv_port=12346,
 ):
-    import sounddevice as sd
   
     send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    send_socket.connect((listen_play_host, listen_play_send_port))
+    send_socket.connect((host, send_port))
 
     recv_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    recv_socket.connect((listen_play_host, listen_play_recv_port))
+    recv_socket.connect((host, recv_port))
 
     print("Recording and streaming...")
 
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 25b6f30..f89a550 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -34,11 +34,10 @@ from utils import (
     int2float,
 )
 
-from listen_and_play import ListenAndPlayArguments
-
 
 console = Console()
 
+
 @dataclass
 class ModuleArguments:
     log_level: str = field(
@@ -647,7 +646,6 @@ def main():
         WhisperSTTHandlerArguments,
         LanguageModelHandlerArguments,
         ParlerTTSHandlerArguments,
-        ListenAndPlayArguments
     ))
 
     # 0. Parse CLI arguments
@@ -661,7 +659,6 @@ def main():
             whisper_stt_handler_kwargs, 
             language_model_handler_kwargs, 
             parler_tts_handler_kwargs,
-            listen_and_play_kwargs,
         ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         # Parse arguments from command line if no JSON file is provided
@@ -673,7 +670,6 @@ def main():
             whisper_stt_handler_kwargs, 
             language_model_handler_kwargs, 
             parler_tts_handler_kwargs,
-            listen_and_play_kwargs
         ) = parser.parse_args_into_dataclasses()
 
     global logger
@@ -753,10 +749,16 @@ def main():
 
         if module_kwargs.client:
             from listen_and_play import listen_and_play
+            
+            kwargs = {
+                "host": "localhost",
+                "send_port": socket_sender_kwargs.send_port,
+                "recv_port": socket_receiver_kwargs.recv_port,
+            }
 
             listen_and_play_process = multiprocessing.Process(
                 target=listen_and_play, 
-                kwargs=vars(listen_and_play_kwargs),
+                kwargs=kwargs,
             )
             listen_and_play_process.start()
 
-- 
GitLab