Skip to content
Snippets Groups Projects
Unverified Commit 8c7272b7 authored by eustlb's avatar eustlb Committed by GitHub
Browse files

Merge pull request #106 from huggingface/refactor_for_inference

Refactor for inference
parents 0ae1b01d 0bc30b68
No related branches found
No related tags found
No related merge requests found
...@@ -49,11 +49,10 @@ console = Console() ...@@ -49,11 +49,10 @@ console = Console()
logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs
def prepare_args(args, prefix): def rename_args(args, prefix):
""" """
Rename arguments by removing the prefix and prepares the gen_kwargs. Rename arguments by removing the prefix and prepares the gen_kwargs.
""" """
gen_kwargs = {} gen_kwargs = {}
for key in copy(args.__dict__): for key in copy(args.__dict__):
if key.startswith(prefix): if key.startswith(prefix):
...@@ -67,7 +66,7 @@ def prepare_args(args, prefix): ...@@ -67,7 +66,7 @@ def prepare_args(args, prefix):
args.__dict__["gen_kwargs"] = gen_kwargs args.__dict__["gen_kwargs"] = gen_kwargs
def main(): def parse_arguments():
parser = HfArgumentParser( parser = HfArgumentParser(
( (
ModuleArguments, ModuleArguments,
...@@ -84,69 +83,43 @@ def main(): ...@@ -84,69 +83,43 @@ def main():
) )
) )
# 0. Parse CLI arguments
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# Parse configurations from a JSON file if specified # Parse configurations from a JSON file if specified
( return parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else: else:
# Parse arguments from command line if no JSON file is provided # Parse arguments from command line if no JSON file is provided
( return parser.parse_args_into_dataclasses()
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs, def setup_logger(log_level):
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
) = parser.parse_args_into_dataclasses()
# 1. Handle logger
global logger global logger
logging.basicConfig( logging.basicConfig(
level=module_kwargs.log_level.upper(), level=log_level.upper(),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# torch compile logs # torch compile logs
if module_kwargs.log_level == "debug": if log_level == "debug":
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True) torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
    """Force Apple-Silicon-friendly defaults onto the given argument namespaces.

    When *mac_optimal_settings* is truthy, every namespace in *handler_kwargs*
    has its device/mode/stt/llm/tts attributes — those it actually defines —
    overwritten with the Mac-optimal choices (MPS device, local mode, MLX
    backends, melo TTS). Falsy flag: no-op.
    """
    if not mac_optimal_settings:
        return
    # Attribute/value pairs, applied in the same order the checks originally ran.
    overrides = (
        ("device", "mps"),
        ("mode", "local"),
        ("stt", "whisper-mlx"),
        ("llm", "mlx-lm"),
        ("tts", "melo"),
    )
    for kwargs in handler_kwargs:
        for attribute, value in overrides:
            if hasattr(kwargs, attribute):
                setattr(kwargs, attribute, value)
def check_mac_settings(module_kwargs):
if platform == "darwin": if platform == "darwin":
if module_kwargs.device == "cuda": if module_kwargs.device == "cuda":
raise ValueError( raise ValueError(
...@@ -161,46 +134,90 @@ def main(): ...@@ -161,46 +134,90 @@ def main():
"If you experiences issues generating the voice, considering setting the tts to melo." "If you experiences issues generating the voice, considering setting the tts to melo."
) )
# 2. Prepare each part's arguments
def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
    """Propagate one shared device string onto every handler's device field.

    Each namespace in *handler_kwargs* has whichever of its per-handler device
    attributes exist overwritten with *common_device*. A falsy *common_device*
    leaves everything untouched.
    """
    if not common_device:
        return
    device_attributes = (
        "lm_device",
        "tts_device",
        "stt_device",
        "paraformer_stt_device",
    )
    for kwargs in handler_kwargs:
        for attribute in device_attributes:
            if hasattr(kwargs, attribute):
                setattr(kwargs, attribute, common_device)
def prepare_module_args(module_kwargs, *handler_kwargs):
    """Apply platform-wide argument tweaks before handlers read their kwargs.

    Steps, in order:
    1. Optionally force the Mac-optimal backend selection onto module_kwargs.
    2. On macOS, run check_mac_settings (raises on unusable device choices).
    3. Fan the shared --device value out to every handler's device field.
    """
    optimal_mac_settings(module_kwargs.local_mac_optimal_settings, module_kwargs)
    if platform == "darwin":
        check_mac_settings(module_kwargs)
    overwrite_device_argument(module_kwargs.device, *handler_kwargs)
def prepare_all_args(
    module_kwargs,
    whisper_stt_handler_kwargs,
    paraformer_stt_handler_kwargs,
    language_model_handler_kwargs,
    mlx_language_model_handler_kwargs,
    parler_tts_handler_kwargs,
    melo_tts_handler_kwargs,
    chat_tts_handler_kwargs,
):
    """Normalize every handler's argument namespace for pipeline construction.

    First applies the module-level tweaks (Mac defaults, device fan-out),
    then strips each handler's CLI prefix and collects its gen_kwargs.
    """
    prepare_module_args(
        module_kwargs,
        whisper_stt_handler_kwargs,
        paraformer_stt_handler_kwargs,
        language_model_handler_kwargs,
        mlx_language_model_handler_kwargs,
        parler_tts_handler_kwargs,
        melo_tts_handler_kwargs,
        chat_tts_handler_kwargs,
    )

    # Handler namespace -> the CLI prefix its arguments were registered under.
    prefixed_handlers = (
        (whisper_stt_handler_kwargs, "stt"),
        (paraformer_stt_handler_kwargs, "paraformer_stt"),
        (language_model_handler_kwargs, "lm"),
        (mlx_language_model_handler_kwargs, "mlx_lm"),
        (parler_tts_handler_kwargs, "tts"),
        (melo_tts_handler_kwargs, "melo"),
        (chat_tts_handler_kwargs, "chat_tts"),
    )
    for handler_kwargs, prefix in prefixed_handlers:
        rename_args(handler_kwargs, prefix)
def initialize_queues_and_events():
    """Create the shared synchronization primitives and hand-off queues.

    Returns a dict keyed by name: two events (the pipeline stop signal and
    ``should_listen`` — the gate that stops queuing received audio chunks
    until the TTS has finished speaking) plus the five queues that connect
    the receiver, VAD, STT, LLM, and TTS stages.
    """
    resources = {name: Event() for name in ("stop_event", "should_listen")}
    queue_names = (
        "recv_audio_chunks_queue",
        "send_audio_chunks_queue",
        "spoken_prompt_queue",
        "text_prompt_queue",
        "lm_response_queue",
    )
    resources.update((name, Queue()) for name in queue_names)
    return resources
def build_pipeline(
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
queues_and_events,
):
stop_event = queues_and_events["stop_event"]
should_listen = queues_and_events["should_listen"]
recv_audio_chunks_queue = queues_and_events["recv_audio_chunks_queue"]
send_audio_chunks_queue = queues_and_events["send_audio_chunks_queue"]
spoken_prompt_queue = queues_and_events["spoken_prompt_queue"]
text_prompt_queue = queues_and_events["text_prompt_queue"]
lm_response_queue = queues_and_events["lm_response_queue"]
if module_kwargs.mode == "local": if module_kwargs.mode == "local":
from connections.local_audio_streamer import LocalAudioStreamer from connections.local_audio_streamer import LocalAudioStreamer
...@@ -238,10 +255,18 @@ def main(): ...@@ -238,10 +255,18 @@ def main():
setup_args=(should_listen,), setup_args=(should_listen,),
setup_kwargs=vars(vad_handler_kwargs), setup_kwargs=vars(vad_handler_kwargs),
) )
stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs)
tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs)
return ThreadManager([*comms_handlers, vad, stt, lm, tts])
def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs):
if module_kwargs.stt == "whisper": if module_kwargs.stt == "whisper":
from STT.whisper_stt_handler import WhisperSTTHandler from STT.whisper_stt_handler import WhisperSTTHandler
return WhisperSTTHandler(
stt = WhisperSTTHandler(
stop_event, stop_event,
queue_in=spoken_prompt_queue, queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue, queue_out=text_prompt_queue,
...@@ -249,8 +274,7 @@ def main(): ...@@ -249,8 +274,7 @@ def main():
) )
elif module_kwargs.stt == "whisper-mlx": elif module_kwargs.stt == "whisper-mlx":
from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
return LightningWhisperSTTHandler(
stt = LightningWhisperSTTHandler(
stop_event, stop_event,
queue_in=spoken_prompt_queue, queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue, queue_out=text_prompt_queue,
...@@ -258,21 +282,20 @@ def main(): ...@@ -258,21 +282,20 @@ def main():
) )
elif module_kwargs.stt == "paraformer": elif module_kwargs.stt == "paraformer":
from STT.paraformer_handler import ParaformerSTTHandler from STT.paraformer_handler import ParaformerSTTHandler
return ParaformerSTTHandler(
stt = ParaformerSTTHandler(
stop_event, stop_event,
queue_in=spoken_prompt_queue, queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue, queue_out=text_prompt_queue,
setup_kwargs=vars(paraformer_stt_handler_kwargs), setup_kwargs=vars(paraformer_stt_handler_kwargs),
) )
else: else:
raise ValueError( raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.")
"The STT should be either whisper, whisper-mlx, or paraformer."
)
def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs):
if module_kwargs.llm == "transformers": if module_kwargs.llm == "transformers":
from LLM.language_model import LanguageModelHandler from LLM.language_model import LanguageModelHandler
return LanguageModelHandler(
lm = LanguageModelHandler(
stop_event, stop_event,
queue_in=text_prompt_queue, queue_in=text_prompt_queue,
queue_out=lm_response_queue, queue_out=lm_response_queue,
...@@ -280,8 +303,7 @@ def main(): ...@@ -280,8 +303,7 @@ def main():
) )
elif module_kwargs.llm == "mlx-lm": elif module_kwargs.llm == "mlx-lm":
from LLM.mlx_language_model import MLXLanguageModelHandler from LLM.mlx_language_model import MLXLanguageModelHandler
return MLXLanguageModelHandler(
lm = MLXLanguageModelHandler(
stop_event, stop_event,
queue_in=text_prompt_queue, queue_in=text_prompt_queue,
queue_out=lm_response_queue, queue_out=lm_response_queue,
...@@ -289,10 +311,12 @@ def main(): ...@@ -289,10 +311,12 @@ def main():
) )
else: else:
raise ValueError("The LLM should be either transformers or mlx-lm") raise ValueError("The LLM should be either transformers or mlx-lm")
def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs):
if module_kwargs.tts == "parler": if module_kwargs.tts == "parler":
from TTS.parler_handler import ParlerTTSHandler from TTS.parler_handler import ParlerTTSHandler
return ParlerTTSHandler(
tts = ParlerTTSHandler(
stop_event, stop_event,
queue_in=lm_response_queue, queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue, queue_out=send_audio_chunks_queue,
...@@ -307,7 +331,7 @@ def main(): ...@@ -307,7 +331,7 @@ def main():
"Error importing MeloTTSHandler. You might need to run: python -m unidic download" "Error importing MeloTTSHandler. You might need to run: python -m unidic download"
) )
raise e raise e
tts = MeloTTSHandler( return MeloTTSHandler(
stop_event, stop_event,
queue_in=lm_response_queue, queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue, queue_out=send_audio_chunks_queue,
...@@ -320,7 +344,7 @@ def main(): ...@@ -320,7 +344,7 @@ def main():
except RuntimeError as e: except RuntimeError as e:
logger.error("Error importing ChatTTSHandler") logger.error("Error importing ChatTTSHandler")
raise e raise e
tts = ChatTTSHandler( return ChatTTSHandler(
stop_event, stop_event,
queue_in=lm_response_queue, queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue, queue_out=send_audio_chunks_queue,
...@@ -330,14 +354,57 @@ def main(): ...@@ -330,14 +354,57 @@ def main():
else: else:
raise ValueError("The TTS should be either parler, melo or chatTTS") raise ValueError("The TTS should be either parler, melo or chatTTS")
def main():
    """Entry point: parse arguments, configure logging, build and run the pipeline.

    Runs until interrupted; Ctrl-C triggers an orderly shutdown of all
    pipeline threads via ThreadManager.stop().
    """
    # Parse CLI args (or a single JSON config file) into one dataclass per
    # module/handler.
    (
        module_kwargs,
        socket_receiver_kwargs,
        socket_sender_kwargs,
        vad_handler_kwargs,
        whisper_stt_handler_kwargs,
        paraformer_stt_handler_kwargs,
        language_model_handler_kwargs,
        mlx_language_model_handler_kwargs,
        parler_tts_handler_kwargs,
        melo_tts_handler_kwargs,
        chat_tts_handler_kwargs,
    ) = parse_arguments()

    setup_logger(module_kwargs.log_level)

    # Normalize handler kwargs (Mac defaults, device fan-out, prefix stripping)
    # before any handler is constructed.
    prepare_all_args(
        module_kwargs,
        whisper_stt_handler_kwargs,
        paraformer_stt_handler_kwargs,
        language_model_handler_kwargs,
        mlx_language_model_handler_kwargs,
        parler_tts_handler_kwargs,
        melo_tts_handler_kwargs,
        chat_tts_handler_kwargs,
    )

    queues_and_events = initialize_queues_and_events()

    # Assemble comms + VAD + STT + LLM + TTS handlers into a ThreadManager.
    pipeline_manager = build_pipeline(
        module_kwargs,
        socket_receiver_kwargs,
        socket_sender_kwargs,
        vad_handler_kwargs,
        whisper_stt_handler_kwargs,
        paraformer_stt_handler_kwargs,
        language_model_handler_kwargs,
        mlx_language_model_handler_kwargs,
        parler_tts_handler_kwargs,
        melo_tts_handler_kwargs,
        chat_tts_handler_kwargs,
        queues_and_events,
    )

    try:
        pipeline_manager.start()
    except KeyboardInterrupt:
        pipeline_manager.stop()
# Script entry point guard: run the pipeline only when executed directly.
if __name__ == "__main__":
    main()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment