import logging
from time import perf_counter

from baseHandler import BaseHandler
from funasr import AutoModel
import numpy as np
from rich.console import Console
import torch

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class ParaformerSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Paraformer model.
    The default for this model is set to Chinese.
    This model was contributed by @wuhongsheng.
    """

    def setup(
        self,
        model_name="paraformer-zh",
        device="cuda",
        gen_kwargs=None,
    ):
        """Load the FunASR Paraformer model and run a warmup pass.

        Args:
            model_name: FunASR model id. A hub-style "org/name" id is reduced
                to its final path component, which is what AutoModel expects.
            device: device string the model runs on ("cuda", "mps", "cpu", ...).
            gen_kwargs: unused by this handler; accepted (default None instead
                of a shared mutable {}) for interface parity with the other
                STT handlers in this project.
        """
        # Was a stray debug print(); route through the module logger instead.
        logger.info("Loading Paraformer model %s", model_name)
        # FunASR wants the bare model name, not a hub-style "org/name" path.
        if "/" in model_name:
            model_name = model_name.rsplit("/", 1)[-1]
        self.device = device
        self.model = AutoModel(model=model_name, device=device)
        self.warmup()

    def warmup(self):
        """Run one dummy inference so the first real request is not slow."""
        logger.info(f"Warming up {self.__class__.__name__}")

        # A single step is enough here: this handler does no torch.compile /
        # CUDA-graphs capture, so warmup only pre-loads weights and kernels.
        # (The comment previously claimed 2 steps while n_steps was 1.)
        n_steps = 1
        dummy_input = np.array([0] * 512, dtype=np.float32)
        for _ in range(n_steps):
            _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "")

    def process(self, spoken_prompt):
        """Transcribe one spoken prompt and yield the predicted text.

        Args:
            spoken_prompt: audio samples for one utterance — presumably a
                float32 waveform like the warmup input; confirm with the VAD
                producer feeding queue_in.

        Yields:
            The transcription with surrounding whitespace and inter-token
            spaces removed (spaces are FunASR token separators, unwanted for
            Chinese output).
        """
        logger.debug("infering paraformer...")

        # Module-level timing marker shared with the rest of the pipeline.
        global pipeline_start
        pipeline_start = perf_counter()

        # FunASR returns a list of result dicts; the transcript is in "text".
        pred_text = (
            self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
        )
        # Only flush the MPS allocator when actually running on MPS: calling
        # torch.mps.empty_cache() unconditionally breaks on the default
        # device="cuda" / non-Mac torch builds.
        if self.device == "mps":
            torch.mps.empty_cache()

        logger.debug("finished paraformer inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text
from dataclasses import dataclass, field


@dataclass
class ParaformerSTTHandlerArguments:
    """Command-line arguments for the Paraformer speech-to-text handler.

    Field names are prefixed with "paraformer_stt_" so prepare_args() can
    strip the prefix and forward them to ParaformerSTTHandler.setup().
    """

    # FunASR model identifier; "paraformer-zh" targets Chinese by default.
    paraformer_stt_model_name: str = field(
        default="paraformer-zh",
        metadata={
            # Fixed help-text grammar ("Can be choose" -> "Can be chosen").
            "help": "The pretrained model to use. Default is 'paraformer-zh'. Can be chosen from https://github.com/modelscope/FunASR"
        },
    )
    # Device string passed straight to FunASR's AutoModel.
    paraformer_stt_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        },
    )
arguments_classes.paraformer_stt_arguments import ParaformerSTTHandlerArguments from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments from arguments_classes.socket_receiver_arguments import SocketReceiverArguments from arguments_classes.socket_sender_arguments import SocketSenderArguments @@ -74,6 +75,7 @@ def main(): SocketSenderArguments, VADHandlerArguments, WhisperSTTHandlerArguments, + ParaformerSTTHandlerArguments, LanguageModelHandlerArguments, MLXLanguageModelHandlerArguments, ParlerTTSHandlerArguments, @@ -91,6 +93,7 @@ def main(): socket_sender_kwargs, vad_handler_kwargs, whisper_stt_handler_kwargs, + paraformer_stt_handler_kwargs, language_model_handler_kwargs, mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, @@ -105,6 +108,7 @@ def main(): socket_sender_kwargs, vad_handler_kwargs, whisper_stt_handler_kwargs, + paraformer_stt_handler_kwargs, language_model_handler_kwargs, mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, @@ -167,6 +171,8 @@ def main(): kwargs.tts_device = common_device if hasattr(kwargs, "stt_device"): kwargs.stt_device = common_device + if hasattr(kwargs, "paraformer_stt_device"): + kwargs.paraformer_stt_device = common_device # Call this function with the common device and all the handlers overwrite_device_argument( @@ -175,9 +181,11 @@ def main(): mlx_language_model_handler_kwargs, parler_tts_handler_kwargs, whisper_stt_handler_kwargs, + paraformer_stt_handler_kwargs, ) prepare_args(whisper_stt_handler_kwargs, "stt") + prepare_args(paraformer_stt_handler_kwargs, "paraformer_stt") prepare_args(language_model_handler_kwargs, "lm") prepare_args(mlx_language_model_handler_kwargs, "mlx_lm") prepare_args(parler_tts_handler_kwargs, "tts") @@ -248,8 +256,19 @@ def main(): queue_out=text_prompt_queue, setup_kwargs=vars(whisper_stt_handler_kwargs), ) + elif module_kwargs.stt == "paraformer": + from STT.paraformer_handler import ParaformerSTTHandler + + stt = ParaformerSTTHandler( + stop_event, + 
queue_in=spoken_prompt_queue, + queue_out=text_prompt_queue, + setup_kwargs=vars(paraformer_stt_handler_kwargs), + ) else: - raise ValueError("The STT should be either whisper or whisper-mlx") + raise ValueError( + "The STT should be either whisper, whisper-mlx, or paraformer." + ) if module_kwargs.llm == "transformers": from LLM.language_model import LanguageModelHandler