Commit 8ca9df23 authored by wuhongsheng

feat: add paraformer_zh asr

parent d3d25c45
STT/paraformer_handler.py

import logging
from time import perf_counter

import numpy as np
import torch
from rich.console import Console

from baseHandler import BaseHandler
from funasr import AutoModel

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()
class ParaformerSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a FunASR Paraformer model.
    """

    def setup(
        self,
        model_name="paraformer-zh",
        device="cuda",
        torch_dtype="float32",
        compile_mode=None,
        gen_kwargs={},
    ):
        # FunASR expects the bare model name, so strip any namespace prefix
        # (e.g. "org/paraformer-zh" -> "paraformer-zh").
        if len(model_name.split("/")) > 1:
            model_name = model_name.split("/")[-1]
        logger.info(f"Loading Paraformer model: {model_name}")
        self.device = device
        self.model = AutoModel(model=model_name)
        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        # A single warmup step is enough here since the model is not compiled.
        n_steps = 1
        dummy_input = np.zeros(512, dtype=np.float32)
        for _ in range(n_steps):
            _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "")

    def process(self, spoken_prompt):
        logger.debug("infering paraformer...")

        global pipeline_start
        pipeline_start = perf_counter()

        pred_text = (
            self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
        )
        if self.device == "mps":
            torch.mps.empty_cache()

        logger.debug("finished paraformer inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text
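
For a quick standalone check outside the queue-driven pipeline, the handler can be fed a float32 waveform directly. This is a minimal sketch, assuming BaseHandler's constructor accepts stop_event, queue_in, queue_out, and setup_kwargs and forwards the kwargs to setup(), as the main() wiring below suggests; the 16 kHz, one-second silent buffer is purely illustrative.

from queue import Queue
from threading import Event

import numpy as np

from STT.paraformer_handler import ParaformerSTTHandler

stop_event = Event()
spoken_prompt_queue = Queue()
text_prompt_queue = Queue()

# Mirrors the main() wiring below; the setup_kwargs values here are
# illustrative assumptions, not taken from this commit.
stt = ParaformerSTTHandler(
    stop_event,
    queue_in=spoken_prompt_queue,
    queue_out=text_prompt_queue,
    setup_kwargs={"model_name": "paraformer-zh", "device": "cuda"},
)

# One second of silence at 16 kHz stands in for a real utterance.
dummy_audio = np.zeros(16000, dtype=np.float32)
for pred_text in stt.process(dummy_audio):
    print(pred_text)
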
@@ -243,6 +243,14 @@ def main():
            queue_out=text_prompt_queue,
            setup_kwargs=vars(whisper_stt_handler_kwargs),
        )
    elif module_kwargs.stt == "paraformer":
        from STT.paraformer_handler import ParaformerSTTHandler

        stt = ParaformerSTTHandler(
            stop_event,
            queue_in=spoken_prompt_queue,
            queue_out=text_prompt_queue,
            # Paraformer relies on its setup() defaults; the whisper kwargs do not apply here.
        )
    else:
        raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer")
    if module_kwargs.llm == "transformers":
......
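
With the new elif branch in place, the Paraformer backend can be selected from the command line. A hedged example, assuming the pipeline entry point is s2s_pipeline.py and that module_kwargs.stt is exposed as an --stt argument; both names are inferred from the surrounding code rather than confirmed by this diff.

python s2s_pipeline.py --stt paraformer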