diff --git a/STT/paraformer_handler.py b/STT/paraformer_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a2a9c002716b9acd1b3ddf1c52e847aa571ef2e
--- /dev/null
+++ b/STT/paraformer_handler.py
@@ -0,0 +1,70 @@
+import logging
+from time import perf_counter
+
+from baseHandler import BaseHandler
+from funasr import AutoModel
+import numpy as np
+from rich.console import Console
+import torch
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class ParaformerSTTHandler(BaseHandler):
+    """
+    Handles the Speech To Text generation using a Paraformer model (FunASR).
+    """
+
+    def setup(
+        self,
+        model_name="paraformer-zh",
+        device="cuda",
+        torch_dtype="float32",
+        compile_mode=None,
+        gen_kwargs=None,
+    ):
+        # FunASR expects a bare model identifier; strip any hub namespace
+        # prefix such as "funasr/paraformer-zh".
+        if len(model_name.split("/")) > 1:
+            model_name = model_name.split("/")[-1]
+        logger.info(f"Loading Paraformer model {model_name}")
+        self.device = device
+        # NOTE(review): torch_dtype, compile_mode and gen_kwargs are accepted
+        # for interface parity with the other STT handlers but are not
+        # forwarded to FunASR's AutoModel — confirm whether they should be.
+        self.model = AutoModel(model=model_name)
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+
+        # A single step is enough here: there is no torch.compile / CUDA
+        # graph capture in this handler that would require a second pass.
+        n_steps = 1
+        dummy_input = np.zeros(512, dtype=np.float32)
+        for _ in range(n_steps):
+            _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "")
+
+    def process(self, spoken_prompt):
+        logger.debug("infering paraformer...")
+
+        global pipeline_start
+        pipeline_start = perf_counter()
+
+        pred_text = (
+            self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
+        )
+        # torch.mps.empty_cache() raises on builds without MPS support, so
+        # only clear the cache when configured for Apple silicon.
+        if self.device == "mps":
+            torch.mps.empty_cache()
+
+        logger.debug("finished paraformer inference")
+        console.print(f"[yellow]USER: {pred_text}")
+
+        yield pred_text
diff --git a/requirements.txt b/requirements.txt
index b4a5a0e36820f6b4a6900ada8d3a4b8a6059f669..3faeb785953273e22b1dcc5ed575bdf2c444dfa1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,6 @@ nltk==3.9.1
 parler_tts @ git+https://github.com/huggingface/parler-tts.git
 melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers
 torch==2.4.0
-sounddevice==0.5.0
\ No newline at end of file
+sounddevice==0.5.0
+funasr
+modelscope
\ No newline at end of file
diff --git a/requirements_mac.txt b/requirements_mac.txt
index 24ba6434d116e37e88030d958ada90f56003a65e..e1e864a7da83171ac47b73175c5317a9614493bc 100644
--- a/requirements_mac.txt
+++ b/requirements_mac.txt
@@ -4,4 +4,6 @@ melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a
 torch==2.4.0
 sounddevice==0.5.0
 lightning-whisper-mlx>=0.0.10
-mlx-lm>=0.14.0
\ No newline at end of file
+mlx-lm>=0.14.0
+funasr
+modelscope
\ No newline at end of file
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index c5d8e133be2d493eb6754f7532a5c779de31333c..a5e6aaa8cc3c142ca4838203ce417a2ffac90c41 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -243,6 +243,15 @@ def main():
             queue_out=text_prompt_queue,
             setup_kwargs=vars(whisper_stt_handler_kwargs),
         )
+    elif module_kwargs.stt == "paraformer":
+        from STT.paraformer_handler import ParaformerSTTHandler
+        stt = ParaformerSTTHandler(
+            stop_event,
+            queue_in=spoken_prompt_queue,
+            queue_out=text_prompt_queue,
+        )
     else:
-        raise ValueError("The STT should be either whisper or whisper-mlx")
+        raise ValueError(
+            "The STT should be either whisper, whisper-mlx or paraformer"
+        )
     if module_kwargs.llm == "transformers":