From a211f4d7fc4bc4d8b04cf0c243cc78fc7eb36948 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan <eulebihan@gmail.com>
Date: Thu, 19 Sep 2024 19:34:28 +0200
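Subject: [PATCH] update s2s_pipeline.py to match main

Wire the "open_api" language model option into the pipeline:
- import OpenApiLanguageModelHandlerArguments and add it to the
  argument classes parsed in parse_arguments();
- pass open_api_language_model_handler_kwargs through
  prepare_all_args(), build_pipeline() and get_llm_handler(), and
  register it with rename_args() under the "open_api" prefix;
- return an OpenApiModelHandler from get_llm_handler() when
  module_kwargs.llm == "open_api".
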
---
 s2s_pipeline.py | 29 +++++++++++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 4c0cac1..c0a7c82 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -21,6 +21,7 @@ from arguments_classes.socket_sender_arguments import SocketSenderArguments
 from arguments_classes.vad_arguments import VADHandlerArguments
 from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
 from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
+from arguments_classes.open_api_language_model_arguments import OpenApiLanguageModelHandlerArguments
 import torch
 import nltk
 from rich.console import Console
@@ -76,6 +77,7 @@ def parse_arguments():
             WhisperSTTHandlerArguments,
             ParaformerSTTHandlerArguments,
             LanguageModelHandlerArguments,
+            OpenApiLanguageModelHandlerArguments,
             MLXLanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
             MeloTTSHandlerArguments,
@@ -160,6 +162,7 @@ def prepare_all_args(
     whisper_stt_handler_kwargs,
     paraformer_stt_handler_kwargs,
     language_model_handler_kwargs,
+    open_api_language_model_handler_kwargs,
     mlx_language_model_handler_kwargs,
     parler_tts_handler_kwargs,
     melo_tts_handler_kwargs,
@@ -170,6 +173,7 @@ def prepare_all_args(
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
         language_model_handler_kwargs,
+        open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
@@ -180,6 +184,7 @@ def prepare_all_args(
     rename_args(paraformer_stt_handler_kwargs, "paraformer_stt")
     rename_args(language_model_handler_kwargs, "lm")
     rename_args(mlx_language_model_handler_kwargs, "mlx_lm")
+    rename_args(open_api_language_model_handler_kwargs, "open_api")
     rename_args(parler_tts_handler_kwargs, "tts")
     rename_args(melo_tts_handler_kwargs, "melo")
     rename_args(chat_tts_handler_kwargs, "chat_tts")
@@ -205,6 +210,7 @@ def build_pipeline(
     whisper_stt_handler_kwargs,
     paraformer_stt_handler_kwargs,
     language_model_handler_kwargs,
+    open_api_language_model_handler_kwargs,
     mlx_language_model_handler_kwargs,
     parler_tts_handler_kwargs,
     melo_tts_handler_kwargs,
@@ -257,7 +263,7 @@ def build_pipeline(
     )
 
     stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
-    lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs)
+    lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs)
     tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs)
 
     return ThreadManager([*comms_handlers, vad, stt, lm, tts])
@@ -292,7 +298,15 @@ def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_
         raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.")
 
 
-def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs):
+def get_llm_handler(
+    module_kwargs,
+    stop_event,
+    text_prompt_queue,
+    lm_response_queue,
+    language_model_handler_kwargs,
+    open_api_language_model_handler_kwargs,
+    mlx_language_model_handler_kwargs
+):
     if module_kwargs.llm == "transformers":
         from LLM.language_model import LanguageModelHandler
         return LanguageModelHandler(
@@ -301,6 +315,14 @@ def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_qu
             queue_out=lm_response_queue,
             setup_kwargs=vars(language_model_handler_kwargs),
         )
+    elif module_kwargs.llm == "open_api":
+        from LLM.openai_api_language_model import OpenApiModelHandler
+        return OpenApiModelHandler(
+            stop_event,
+            queue_in=text_prompt_queue,
+            queue_out=lm_response_queue,
+            setup_kwargs=vars(open_api_language_model_handler_kwargs),
+        )
     elif module_kwargs.llm == "mlx-lm":
         from LLM.mlx_language_model import MLXLanguageModelHandler
         return MLXLanguageModelHandler(
@@ -364,6 +386,7 @@ def main():
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
         language_model_handler_kwargs,
+        open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
@@ -377,6 +400,7 @@ def main():
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
         language_model_handler_kwargs,
+        open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
@@ -393,6 +417,7 @@ def main():
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
         language_model_handler_kwargs,
+        open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
-- 
GitLab