From a211f4d7fc4bc4d8b04cf0c243cc78fc7eb36948 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan <eulebihan@gmail.com> Date: Thu, 19 Sep 2024 19:34:28 +0200 Subject: [PATCH] update s2s_pipeline.py to match main --- s2s_pipeline.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/s2s_pipeline.py b/s2s_pipeline.py index 4c0cac1..c0a7c82 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -21,6 +21,7 @@ from arguments_classes.socket_sender_arguments import SocketSenderArguments from arguments_classes.vad_arguments import VADHandlerArguments from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments +from arguments_classes.open_api_language_model_arguments import OpenApiLanguageModelHandlerArguments import torch import nltk from rich.console import Console @@ -76,6 +77,7 @@ def parse_arguments(): WhisperSTTHandlerArguments, ParaformerSTTHandlerArguments, LanguageModelHandlerArguments, + OpenApiLanguageModelHandlerArguments, MLXLanguageModelHandlerArguments, ParlerTTSHandlerArguments, MeloTTSHandlerArguments, @@ -160,6 +162,7 @@ def prepare_all_args( whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs, language_model_handler_kwargs, + open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, @@ -170,6 +173,7 @@ def prepare_all_args( whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs, language_model_handler_kwargs, + open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, @@ -180,6 +184,7 @@ def prepare_all_args( rename_args(paraformer_stt_handler_kwargs, "paraformer_stt") rename_args(language_model_handler_kwargs, "lm") rename_args(mlx_language_model_handler_kwargs, "mlx_lm") + rename_args(open_api_language_model_handler_kwargs, "open_api") rename_args(parler_tts_handler_kwargs, "tts") rename_args(melo_tts_handler_kwargs, "melo") rename_args(chat_tts_handler_kwargs, "chat_tts") @@ -205,6 +210,7 @@ def build_pipeline( whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs, language_model_handler_kwargs, + open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, @@ -257,7 +263,7 @@ def build_pipeline( ) stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs) - lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs) + lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs) tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs) return ThreadManager([*comms_handlers, vad, stt, lm, tts]) @@ -292,7 +298,15 @@ def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_ raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.") -def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs): +def get_llm_handler( + module_kwargs, + stop_event, + text_prompt_queue, + lm_response_queue, + language_model_handler_kwargs, + open_api_language_model_handler_kwargs, + mlx_language_model_handler_kwargs +): if module_kwargs.llm == "transformers": from LLM.language_model import LanguageModelHandler return LanguageModelHandler( @@ -301,6 +315,14 @@ def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_qu queue_out=lm_response_queue, setup_kwargs=vars(language_model_handler_kwargs), ) + elif module_kwargs.llm == "open_api": + from LLM.openai_api_language_model import OpenApiModelHandler + return OpenApiModelHandler( + stop_event, + queue_in=text_prompt_queue, + queue_out=lm_response_queue, + setup_kwargs=vars(open_api_language_model_handler_kwargs), + ) elif module_kwargs.llm == "mlx-lm": from LLM.mlx_language_model import MLXLanguageModelHandler return MLXLanguageModelHandler( @@ -364,6 +386,7 @@ def main(): whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs, language_model_handler_kwargs, + open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, @@ -377,6 +400,7 @@ def main(): whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs, language_model_handler_kwargs, + open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, @@ -393,6 +417,7 @@ def main(): whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs, language_model_handler_kwargs, + open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, -- GitLab