vad_handler.py
    from VAD.vad_iterator import VADIterator
    from baseHandler import BaseHandler
    import numpy as np
    import torch
    import torchaudio
    from rich.console import Console
    
    from utils.utils import int2float
    
    from df.enhance import enhance, init_df
    
    import logging
    
    logger = logging.getLogger(__name__)
    
    console = Console()
    
    
    class VADHandler(BaseHandler):
        """
        Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed
        to the following part.
        """
    
        def setup(
            self,
            should_listen,
            thresh=0.3,
            sample_rate=16000,
            min_silence_ms=1000,
            min_speech_ms=500,
            max_speech_ms=float("inf"),
            speech_pad_ms=30,
            audio_enhancement=False,
        ):
            self.should_listen = should_listen
            self.sample_rate = sample_rate
            self.min_silence_ms = min_silence_ms
            self.min_speech_ms = min_speech_ms
            self.max_speech_ms = max_speech_ms
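            # Load the pretrained Silero VAD model from torch.hub.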
            self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
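            # VADIterator wraps the model for streaming use, applying the threshold,
            # silence, and padding settings to successive audio chunks.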
            self.iterator = VADIterator(
                self.model,
                threshold=thresh,
                sampling_rate=sample_rate,
                min_silence_duration_ms=min_silence_ms,
                speech_pad_ms=speech_pad_ms,
            )
    
            self.audio_enhancement = audio_enhancement
            if audio_enhancement:
                self.enhanced_model, self.df_state, _ = init_df()

        def process(self, audio_chunk):
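            # Convert raw int16 PCM bytes into a normalized float32 waveform.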
            audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
            audio_float32 = int2float(audio_int16)
            vad_output = self.iterator(torch.from_numpy(audio_float32))
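            # The iterator buffers audio internally and only returns the collected
            # speech chunks once it detects the end of an utterance.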
            if vad_output is not None and len(vad_output) != 0:
                logger.debug("VAD: end of speech detected")
                array = torch.cat(vad_output).cpu().numpy()
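                # Gate on utterance duration: skip clips shorter than min_speech_ms
                # or longer than max_speech_ms.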
                duration_ms = len(array) / self.sample_rate * 1000
                if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
                    logger.debug(
                        f"audio input of duration: {duration_ms / 1000:.2f}s, skipping"
                    )
                else:
                    self.should_listen.clear()
                    logger.debug("Stop listening")
    
                    if self.audio_enhancement:
                        if self.sample_rate != self.df_state.sr():
                            # DeepFilterNet expects audio at its own sample rate:
                            # resample up, enhance, then resample back down.
                            audio_float32 = torchaudio.functional.resample(
                                torch.from_numpy(array),
                                orig_freq=self.sample_rate,
                                new_freq=self.df_state.sr(),
                            )
                            enhanced = enhance(
                                self.enhanced_model,
                                self.df_state,
                                audio_float32.unsqueeze(0),
                            )
                            enhanced = torchaudio.functional.resample(
                                enhanced,
                                orig_freq=self.df_state.sr(),
                                new_freq=self.sample_rate,
                            )
                        else:
                            # Sample rates already match: enhance the accumulated
                            # utterance directly (enhance expects a channel dim).
                            enhanced = enhance(
                                self.enhanced_model,
                                self.df_state,
                                torch.from_numpy(array).unsqueeze(0),
                            )
                        array = enhanced.numpy().squeeze()
                    yield array
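
    # Usage sketch (not part of the original file): one way this handler might be
    # driven outside the full pipeline. `should_listen` is assumed to be a
    # threading.Event, the 512-sample chunk size matches what Silero VAD expects
    # at 16 kHz, and `mic_chunks`/`consume` are hypothetical stand-ins for the
    # queue plumbing that BaseHandler normally provides.
    #
    #     from threading import Event
    #
    #     should_listen = Event()
    #     should_listen.set()
    #     vad = VADHandler(...)  # constructor args come from BaseHandler (omitted)
    #     vad.setup(should_listen, thresh=0.3, sample_rate=16000)
    #     for chunk in mic_chunks():          # yields 512-sample int16 byte frames
    #         for utterance in vad.process(chunk):
    #             consume(utterance)          # e.g. hand off to speech-to-text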