Skip to content
Snippets Groups Projects
paraformer_handler.py 1.67 KiB
Newer Older
  • Learn to ignore specific revisions
  • wuhongsheng's avatar
    wuhongsheng committed
    import logging
    from time import perf_counter
    
    from baseHandler import BaseHandler
    from funasr import AutoModel
    import numpy as np
    from rich.console import Console
    import torch
    
# Configure the root logging handler's message format for this process.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)

# Rich console used for colorized, user-facing transcript output.
console = Console()
    
    
    class ParaformerSTTHandler(BaseHandler):
        """
    
        Handles the Speech To Text generation using a Paraformer model.
        The default for this model is set to Chinese.
        This model was contributed by @wuhongsheng.
    
    wuhongsheng's avatar
    wuhongsheng committed
        """
    
        def setup(
            self,
            model_name="paraformer-zh",
            device="cuda",
            gen_kwargs={},
        ):
            print(model_name)
            if len(model_name.split("/")) > 1:
                model_name = model_name.split("/")[-1]
            self.device = device
    
            self.model = AutoModel(model=model_name, device=device)
    
    wuhongsheng's avatar
    wuhongsheng committed
            self.warmup()
    
        def warmup(self):
            logger.info(f"Warming up {self.__class__.__name__}")
    
            # 2 warmup steps for no compile or compile mode with CUDA graphs capture
            n_steps = 1
    
            dummy_input = np.array([0] * 512, dtype=np.float32)
    
    wuhongsheng's avatar
    wuhongsheng committed
            for _ in range(n_steps):
    
                _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "")
    
    wuhongsheng's avatar
    wuhongsheng committed
    
        def process(self, spoken_prompt):
            logger.debug("infering paraformer...")
    
            global pipeline_start
            pipeline_start = perf_counter()
    
    
            pred_text = (
                self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
            )
    
    wuhongsheng's avatar
    wuhongsheng committed
            torch.mps.empty_cache()
    
            logger.debug("finished paraformer inference")
            console.print(f"[yellow]USER: {pred_text}")
    
            yield pred_text