Skip to content
Snippets Groups Projects
lightning_whisper_mlx_handler.py 2.74 KiB
Newer Older
  • Learn to ignore specific revisions
  • Andres Marafioti's avatar
    Andres Marafioti committed
    import logging
    from copy import copy
    from time import perf_counter

    import numpy as np
    import torch
    from lightning_whisper_mlx import LightningWhisperMLX
    from rich.console import Console

    from baseHandler import BaseHandler

    # Module-wide logger, named after this module.
    logger = logging.getLogger(name=__name__)

    # Rich console used to echo the user's transcribed text.
    console = Console()
    
    
    # Language codes accepted from Whisper's automatic language detection.
    # LightningWhisperSTTHandler.process checks detected codes against this
    # list when started with language="auto"; any other detected code makes it
    # fall back to the last accepted language (or an empty "en" result).
    SUPPORTED_LANGUAGES = [
        "en",
        "fr",
        "es",
        "zh",
        "ja",
        "ko",
        "hi",
        "de",
        "pt",
        "pl",
        "it",
        "nl",
    ]

    
    class LightningWhisperSTTHandler(BaseHandler):
        """
        Handles the Speech To Text generation using a Whisper model.

        Transcription is delegated to ``LightningWhisperMLX``. ``setup`` loads the
        model and runs a warmup pass; ``process`` transcribes one audio prompt and
        yields a ``(text, language_code)`` tuple.
        """

        def setup(
            self,
            model_name="distil-large-v3",
            device="mps",
            torch_dtype="float16",
            compile_mode=None,
            language=None,
            gen_kwargs=None,
        ):
            """Load the Whisper model and warm it up.

            Args:
                model_name: Model identifier. Only the part after the last "/"
                    is handed to ``LightningWhisperMLX``, so hub-style ids such
                    as "org/distil-large-v3" are accepted.
                device: Stored on ``self.device``; not passed to the model here.
                torch_dtype: Accepted for interface compatibility; unused here.
                compile_mode: Accepted for interface compatibility; unused here.
                language: Language code forced on every transcription, or
                    "auto" to use Whisper's detection restricted to
                    ``SUPPORTED_LANGUAGES``. ``None`` is forwarded to
                    ``transcribe`` as ``language=None``.
                gen_kwargs: Accepted for interface compatibility; unused here.
                    Defaults to ``None`` instead of a shared mutable ``{}``.
            """
            # Strip any "org/" prefix; a name without "/" is returned unchanged.
            model_name = model_name.rsplit("/", 1)[-1]

            self.device = device

            self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)

            # start_language never changes; last_language tracks the most recent
            # accepted detection in "auto" mode.
            self.start_language = language
            self.last_language = language

            self.warmup()

        def warmup(self):
            """Run a dummy transcription on an all-zero buffer before serving requests."""
            logger.info(f"Warming up {self.__class__.__name__}")

            # A single warmup step is performed.
            n_steps = 1
            # NOTE(review): integer array; real prompts are presumably float
            # audio samples — confirm transcribe() is happy with this dtype.
            dummy_input = np.array([0] * 512)

            for _ in range(n_steps):
                self.model.transcribe(dummy_input)

        def process(self, spoken_prompt):
            """Transcribe one audio prompt and yield ``(text, language_code)``.

            If ``start_language`` is not "auto", the prompt is transcribed once
            with that language forced. In "auto" mode the language is detected;
            when the detected code is not in ``SUPPORTED_LANGUAGES`` the prompt
            is re-transcribed with the last accepted language, or, if there is
            none, an empty English result is used. In "auto" mode the yielded
            language code carries an "-auto" suffix.
            """
            logger.debug("infering whisper...")

            # Module-level timestamp marking the start of this pipeline step.
            global pipeline_start
            pipeline_start = perf_counter()

            if self.start_language != "auto":
                transcription_dict = self.model.transcribe(
                    spoken_prompt, language=self.start_language
                )
            else:
                transcription_dict = self.model.transcribe(spoken_prompt)
                language_code = transcription_dict["language"]
                if language_code not in SUPPORTED_LANGUAGES:
                    logger.warning(f"Whisper detected unsupported language: {language_code}")
                    if self.last_language in SUPPORTED_LANGUAGES:
                        # Reprocess with the last accepted language.
                        transcription_dict = self.model.transcribe(
                            spoken_prompt, language=self.last_language
                        )
                    else:
                        transcription_dict = {"text": "", "language": "en"}
                else:
                    self.last_language = language_code

            pred_text = transcription_dict["text"].strip()
            language_code = transcription_dict["language"]

            # Release memory cached by PyTorch's MPS allocator between requests.
            torch.mps.empty_cache()

            logger.debug("finished whisper inference")
            console.print(f"[yellow]USER: {pred_text}")

            logger.debug(f"Language Code Whisper: {language_code}")

            if self.start_language == "auto":
                language_code += "-auto"

            yield (pred_text, language_code)