Skip to content
Snippets Groups Projects
Commit 9941f6f3 authored by Andres Marafioti's avatar Andres Marafioti
Browse files

refactor s2s_pipeline

parent 0ae1b01d
No related branches found
No related tags found
No related merge requests found
......@@ -67,7 +67,7 @@ def prepare_args(args, prefix):
args.__dict__["gen_kwargs"] = gen_kwargs
def main():
def parse_arguments():
parser = HfArgumentParser(
(
ModuleArguments,
......@@ -84,69 +84,40 @@ def main():
)
)
# 0. Parse CLI arguments
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# Parse configurations from a JSON file if specified
(
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
return parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
# Parse arguments from command line if no JSON file is provided
(
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
) = parser.parse_args_into_dataclasses()
# 1. Handle logger
return parser.parse_args_into_dataclasses()
def setup_logger(log_level):
global logger
logging.basicConfig(
level=module_kwargs.log_level.upper(),
level=log_level.upper(),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# torch compile logs
if module_kwargs.log_level == "debug":
if log_level == "debug":
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
if mac_optimal_settings:
for kwargs in handler_kwargs:
if hasattr(kwargs, "device"):
kwargs.device = "mps"
if hasattr(kwargs, "mode"):
kwargs.mode = "local"
if hasattr(kwargs, "stt"):
kwargs.stt = "whisper-mlx"
if hasattr(kwargs, "llm"):
kwargs.llm = "mlx-lm"
if hasattr(kwargs, "tts"):
kwargs.tts = "melo"
optimal_mac_settings(
module_kwargs.local_mac_optimal_settings,
module_kwargs,
)
def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
if mac_optimal_settings:
for kwargs in handler_kwargs:
if hasattr(kwargs, "device"):
kwargs.device = "mps"
if hasattr(kwargs, "mode"):
kwargs.mode = "local"
if hasattr(kwargs, "stt"):
kwargs.stt = "whisper-mlx"
if hasattr(kwargs, "llm"):
kwargs.llm = "mlx-lm"
if hasattr(kwargs, "tts"):
kwargs.tts = "melo"
def check_mac_settings(module_kwargs):
if platform == "darwin":
if module_kwargs.device == "cuda":
raise ValueError(
......@@ -161,29 +132,29 @@ def main():
"If you experiences issues generating the voice, considering setting the tts to melo."
)
# 2. Prepare each part's arguments
def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
if common_device:
for kwargs in handler_kwargs:
if hasattr(kwargs, "lm_device"):
kwargs.lm_device = common_device
if hasattr(kwargs, "tts_device"):
kwargs.tts_device = common_device
if hasattr(kwargs, "stt_device"):
kwargs.stt_device = common_device
if hasattr(kwargs, "paraformer_stt_device"):
kwargs.paraformer_stt_device = common_device
# Call this function with the common device and all the handlers
overwrite_device_argument(
module_kwargs.device,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
)
def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
if common_device:
for kwargs in handler_kwargs:
if hasattr(kwargs, "lm_device"):
kwargs.lm_device = common_device
if hasattr(kwargs, "tts_device"):
kwargs.tts_device = common_device
if hasattr(kwargs, "stt_device"):
kwargs.stt_device = common_device
if hasattr(kwargs, "paraformer_stt_device"):
kwargs.paraformer_stt_device = common_device
def prepare_all_args(
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
):
prepare_args(whisper_stt_handler_kwargs, "stt")
prepare_args(paraformer_stt_handler_kwargs, "paraformer_stt")
prepare_args(language_model_handler_kwargs, "lm")
......@@ -192,7 +163,20 @@ def main():
prepare_args(melo_tts_handler_kwargs, "melo")
prepare_args(chat_tts_handler_kwargs, "chat_tts")
# 3. Build the pipeline
def build_pipeline(
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
):
stop_event = Event()
# used to stop putting received audio chunks in queue until all setences have been processed by the TTS
should_listen = Event()
......@@ -238,10 +222,18 @@ def main():
setup_args=(should_listen,),
setup_kwargs=vars(vad_handler_kwargs),
)
stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs)
tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs)
return ThreadManager([*comms_handlers, vad, stt, lm, tts])
def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs):
if module_kwargs.stt == "whisper":
from STT.whisper_stt_handler import WhisperSTTHandler
stt = WhisperSTTHandler(
return WhisperSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue,
......@@ -249,8 +241,7 @@ def main():
)
elif module_kwargs.stt == "whisper-mlx":
from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
stt = LightningWhisperSTTHandler(
return LightningWhisperSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue,
......@@ -258,21 +249,20 @@ def main():
)
elif module_kwargs.stt == "paraformer":
from STT.paraformer_handler import ParaformerSTTHandler
stt = ParaformerSTTHandler(
return ParaformerSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue,
setup_kwargs=vars(paraformer_stt_handler_kwargs),
)
else:
raise ValueError(
"The STT should be either whisper, whisper-mlx, or paraformer."
)
raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.")
def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs):
if module_kwargs.llm == "transformers":
from LLM.language_model import LanguageModelHandler
lm = LanguageModelHandler(
return LanguageModelHandler(
stop_event,
queue_in=text_prompt_queue,
queue_out=lm_response_queue,
......@@ -280,8 +270,7 @@ def main():
)
elif module_kwargs.llm == "mlx-lm":
from LLM.mlx_language_model import MLXLanguageModelHandler
lm = MLXLanguageModelHandler(
return MLXLanguageModelHandler(
stop_event,
queue_in=text_prompt_queue,
queue_out=lm_response_queue,
......@@ -289,10 +278,12 @@ def main():
)
else:
raise ValueError("The LLM should be either transformers or mlx-lm")
def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs):
if module_kwargs.tts == "parler":
from TTS.parler_handler import ParlerTTSHandler
tts = ParlerTTSHandler(
return ParlerTTSHandler(
stop_event,
queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue,
......@@ -307,7 +298,7 @@ def main():
"Error importing MeloTTSHandler. You might need to run: python -m unidic download"
)
raise e
tts = MeloTTSHandler(
return MeloTTSHandler(
stop_event,
queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue,
......@@ -320,7 +311,7 @@ def main():
except RuntimeError as e:
logger.error("Error importing ChatTTSHandler")
raise e
tts = ChatTTSHandler(
return ChatTTSHandler(
stop_event,
queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue,
......@@ -330,14 +321,69 @@ def main():
else:
raise ValueError("The TTS should be either parler, melo or chatTTS")
# 4. Run the pipeline
def main():
(
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
) = parse_arguments()
setup_logger(module_kwargs.log_level)
optimal_mac_settings(
module_kwargs.local_mac_optimal_settings,
module_kwargs,
)
check_mac_settings(module_kwargs)
overwrite_device_argument(
module_kwargs.device,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
)
prepare_all_args(
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
)
pipeline_manager = build_pipeline(
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
)
try:
pipeline_manager = ThreadManager([*comms_handlers, vad, stt, lm, tts])
pipeline_manager.start()
except KeyboardInterrupt:
pipeline_manager.stop()
if __name__ == "__main__":
main()
main()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment