From 8ca9df23ade4ee292083c0b641f09ac482fc137c Mon Sep 17 00:00:00 2001
From: wuhongsheng <664116298@qq.com>
Date: Mon, 26 Aug 2024 22:33:10 +0800
Subject: [PATCH] feat:add paraformer_zh asr

---
 STT/paraformer_handler.py | 61 +++++++++++++++++++++++++++++++++++++++
 s2s_pipeline.py           |  8 +++++
 2 files changed, 69 insertions(+)
 create mode 100644 STT/paraformer_handler.py

diff --git a/STT/paraformer_handler.py b/STT/paraformer_handler.py
new file mode 100644
index 0000000..0a2a9c0
--- /dev/null
+++ b/STT/paraformer_handler.py
@@ -0,0 +1,61 @@
+import logging
+from time import perf_counter
+
+from tensorstore import dtype
+
+from baseHandler import BaseHandler
+from funasr import AutoModel
+import numpy as np
+from rich.console import Console
+import torch
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class ParaformerSTTHandler(BaseHandler):
+    """
+    Handles the Speech To Text generation using a Whisper model.
+    """
+
+    def setup(
+        self,
+        model_name="paraformer-zh",
+        device="cuda",
+        torch_dtype="float32",
+        compile_mode=None,
+        gen_kwargs={},
+    ):
+        print(model_name)
+        if len(model_name.split("/")) > 1:
+            model_name = model_name.split("/")[-1]
+        self.device = device
+        self.model = AutoModel(model=model_name)
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+
+        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
+        n_steps = 1
+        dummy_input = np.array([0] * 512,dtype=np.float32)
+        for _ in range(n_steps):
+            _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ","")
+
+    def process(self, spoken_prompt):
+        logger.debug("infering paraformer...")
+
+        global pipeline_start
+        pipeline_start = perf_counter()
+
+        pred_text = self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ","")
+        torch.mps.empty_cache()
+
+        logger.debug("finished paraformer inference")
+        console.print(f"[yellow]USER: {pred_text}")
+
+        yield pred_text
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 7a9e204..002cde9 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -243,6 +243,14 @@ def main():
             queue_out=text_prompt_queue,
             setup_kwargs=vars(whisper_stt_handler_kwargs),
         )
+    elif module_kwargs.stt == "paraformer":
+        from STT.paraformer_handler import ParaformerSTTHandler
+        stt = ParaformerSTTHandler(
+            stop_event,
+            queue_in=spoken_prompt_queue,
+            queue_out=text_prompt_queue,
+            # setup_kwargs=vars(whisper_stt_handler_kwargs),
+        )
     else:
         raise ValueError("The STT should be either whisper or whisper-mlx")
     if module_kwargs.llm == "transformers":
-- 
GitLab