diff --git a/s2s_pipeline.py b/s2s_pipeline.py index fc3d433b6ec1feefd53b7289024ab2caa9280891..4c0cac1d4658bf0b44505f44664f63b5e275e4ff 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -49,11 +49,10 @@ console = Console() logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs -def prepare_args(args, prefix): +def rename_args(args, prefix): """ Rename arguments by removing the prefix and prepares the gen_kwargs. """ - gen_kwargs = {} for key in copy(args.__dict__): if key.startswith(prefix): @@ -149,7 +148,15 @@ def overwrite_device_argument(common_device: Optional[str], *handler_kwargs): kwargs.paraformer_stt_device = common_device +def prepare_module_args(module_kwargs, *handler_kwargs): + optimal_mac_settings(module_kwargs.local_mac_optimal_settings, module_kwargs) + if platform == "darwin": + check_mac_settings(module_kwargs) + overwrite_device_argument(module_kwargs.device, *handler_kwargs) + + def prepare_all_args( + module_kwargs, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs, language_model_handler_kwargs, @@ -158,13 +165,24 @@ def prepare_all_args( melo_tts_handler_kwargs, chat_tts_handler_kwargs, ): - prepare_args(whisper_stt_handler_kwargs, "stt") - prepare_args(paraformer_stt_handler_kwargs, "paraformer_stt") - prepare_args(language_model_handler_kwargs, "lm") - prepare_args(mlx_language_model_handler_kwargs, "mlx_lm") - prepare_args(parler_tts_handler_kwargs, "tts") - prepare_args(melo_tts_handler_kwargs, "melo") - prepare_args(chat_tts_handler_kwargs, "chat_tts") + prepare_module_args( + module_kwargs, + whisper_stt_handler_kwargs, + paraformer_stt_handler_kwargs, + language_model_handler_kwargs, + mlx_language_model_handler_kwargs, + parler_tts_handler_kwargs, + melo_tts_handler_kwargs, + chat_tts_handler_kwargs, + ) + + rename_args(whisper_stt_handler_kwargs, "stt") + rename_args(paraformer_stt_handler_kwargs, "paraformer_stt") + rename_args(language_model_handler_kwargs, "lm") + rename_args(mlx_language_model_handler_kwargs, "mlx_lm") + rename_args(parler_tts_handler_kwargs, "tts") + rename_args(melo_tts_handler_kwargs, "melo") + rename_args(chat_tts_handler_kwargs, "chat_tts") def initialize_queues_and_events(): @@ -354,23 +372,8 @@ def main(): setup_logger(module_kwargs.log_level) - optimal_mac_settings( - module_kwargs.local_mac_optimal_settings, - module_kwargs, - ) - - check_mac_settings(module_kwargs) - - overwrite_device_argument( - module_kwargs.device, - language_model_handler_kwargs, - mlx_language_model_handler_kwargs, - parler_tts_handler_kwargs, - whisper_stt_handler_kwargs, - paraformer_stt_handler_kwargs, - ) - prepare_all_args( + module_kwargs, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs, language_model_handler_kwargs,