Skip to content
Snippets Groups Projects
Commit a211f4d7 authored by Eustache Le Bihan's avatar Eustache Le Bihan
Browse files

update s2s_pipeline.py to match main

parent 8c7272b7
No related branches found
No related tags found
No related merge requests found
...@@ -21,6 +21,7 @@ from arguments_classes.socket_sender_arguments import SocketSenderArguments ...@@ -21,6 +21,7 @@ from arguments_classes.socket_sender_arguments import SocketSenderArguments
from arguments_classes.vad_arguments import VADHandlerArguments from arguments_classes.vad_arguments import VADHandlerArguments
from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
from arguments_classes.open_api_language_model_arguments import OpenApiLanguageModelHandlerArguments
import torch import torch
import nltk import nltk
from rich.console import Console from rich.console import Console
...@@ -76,6 +77,7 @@ def parse_arguments(): ...@@ -76,6 +77,7 @@ def parse_arguments():
WhisperSTTHandlerArguments, WhisperSTTHandlerArguments,
ParaformerSTTHandlerArguments, ParaformerSTTHandlerArguments,
LanguageModelHandlerArguments, LanguageModelHandlerArguments,
OpenApiLanguageModelHandlerArguments,
MLXLanguageModelHandlerArguments, MLXLanguageModelHandlerArguments,
ParlerTTSHandlerArguments, ParlerTTSHandlerArguments,
MeloTTSHandlerArguments, MeloTTSHandlerArguments,
...@@ -160,6 +162,7 @@ def prepare_all_args( ...@@ -160,6 +162,7 @@ def prepare_all_args(
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
...@@ -170,6 +173,7 @@ def prepare_all_args( ...@@ -170,6 +173,7 @@ def prepare_all_args(
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
...@@ -180,6 +184,7 @@ def prepare_all_args( ...@@ -180,6 +184,7 @@ def prepare_all_args(
rename_args(paraformer_stt_handler_kwargs, "paraformer_stt") rename_args(paraformer_stt_handler_kwargs, "paraformer_stt")
rename_args(language_model_handler_kwargs, "lm") rename_args(language_model_handler_kwargs, "lm")
rename_args(mlx_language_model_handler_kwargs, "mlx_lm") rename_args(mlx_language_model_handler_kwargs, "mlx_lm")
rename_args(open_api_language_model_handler_kwargs, "open_api")
rename_args(parler_tts_handler_kwargs, "tts") rename_args(parler_tts_handler_kwargs, "tts")
rename_args(melo_tts_handler_kwargs, "melo") rename_args(melo_tts_handler_kwargs, "melo")
rename_args(chat_tts_handler_kwargs, "chat_tts") rename_args(chat_tts_handler_kwargs, "chat_tts")
...@@ -205,6 +210,7 @@ def build_pipeline( ...@@ -205,6 +210,7 @@ def build_pipeline(
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
...@@ -257,7 +263,7 @@ def build_pipeline( ...@@ -257,7 +263,7 @@ def build_pipeline(
) )
stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs) stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs) lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs)
tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs) tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs)
return ThreadManager([*comms_handlers, vad, stt, lm, tts]) return ThreadManager([*comms_handlers, vad, stt, lm, tts])
...@@ -292,7 +298,15 @@ def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_ ...@@ -292,7 +298,15 @@ def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_
raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.") raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.")
def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs): def get_llm_handler(
module_kwargs,
stop_event,
text_prompt_queue,
lm_response_queue,
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs
):
if module_kwargs.llm == "transformers": if module_kwargs.llm == "transformers":
from LLM.language_model import LanguageModelHandler from LLM.language_model import LanguageModelHandler
return LanguageModelHandler( return LanguageModelHandler(
...@@ -301,6 +315,14 @@ def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_qu ...@@ -301,6 +315,14 @@ def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_qu
queue_out=lm_response_queue, queue_out=lm_response_queue,
setup_kwargs=vars(language_model_handler_kwargs), setup_kwargs=vars(language_model_handler_kwargs),
) )
elif module_kwargs.llm == "open_api":
from LLM.openai_api_language_model import OpenApiModelHandler
return OpenApiModelHandler(
stop_event,
queue_in=text_prompt_queue,
queue_out=lm_response_queue,
setup_kwargs=vars(open_api_language_model_handler_kwargs),
)
elif module_kwargs.llm == "mlx-lm": elif module_kwargs.llm == "mlx-lm":
from LLM.mlx_language_model import MLXLanguageModelHandler from LLM.mlx_language_model import MLXLanguageModelHandler
return MLXLanguageModelHandler( return MLXLanguageModelHandler(
...@@ -364,6 +386,7 @@ def main(): ...@@ -364,6 +386,7 @@ def main():
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
...@@ -377,6 +400,7 @@ def main(): ...@@ -377,6 +400,7 @@ def main():
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
...@@ -393,6 +417,7 @@ def main(): ...@@ -393,6 +417,7 @@ def main():
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment