diff --git a/VAD/vad_handler.py b/VAD/vad_handler.py
index 3a36498f6b66f0b0bbb64ddd2a0289fbd3696966..e70a44c2afc6d6a1de456d7c0fb0dd81d7e65e95 100644
--- a/VAD/vad_handler.py
+++ b/VAD/vad_handler.py
@@ -1,3 +1,4 @@
+import torchaudio
 from VAD.vad_iterator import VADIterator
 from baseHandler import BaseHandler
 import numpy as np
@@ -5,7 +6,7 @@ import torch
 
 from rich.console import Console
 from utils.utils import int2float
-
+from df.enhance import enhance, init_df
 import logging
 
 logger = logging.getLogger(__name__)
@@ -28,6 +29,7 @@ class VADHandler(BaseHandler):
         min_speech_ms=500,
         max_speech_ms=float("inf"),
         speech_pad_ms=30,
+        audio_enhancement=True,
     ):
         self.should_listen = should_listen
         self.sample_rate = sample_rate
@@ -42,6 +44,10 @@ class VADHandler(BaseHandler):
             min_silence_duration_ms=min_silence_ms,
             speech_pad_ms=speech_pad_ms,
         )
+        self.audio_enhancement = audio_enhancement
+        if audio_enhancement:
+            # Load the DeepFilterNet model once; df_state carries its native sample rate.
+            self.enhanced_model, self.df_state, _ = init_df()
 
     def process(self, audio_chunk):
         audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
@@ -58,4 +64,14 @@ class VADHandler(BaseHandler):
             else:
                 self.should_listen.clear()
                 logger.debug("Stop listening")
+                if self.audio_enhancement:
+                    # Enhance the full accumulated segment (not just the last chunk),
+                    # resampling to/from DeepFilterNet's native rate when it differs.
+                    if self.sample_rate != self.df_state.sr():
+                        audio_float32 = torchaudio.functional.resample(torch.from_numpy(array), orig_freq=self.sample_rate, new_freq=self.df_state.sr())
+                        enhanced = enhance(self.enhanced_model, self.df_state, audio_float32.unsqueeze(0))
+                        enhanced = torchaudio.functional.resample(enhanced, orig_freq=self.df_state.sr(), new_freq=self.sample_rate)
+                    else:
+                        enhanced = enhance(self.enhanced_model, self.df_state, torch.from_numpy(array).unsqueeze(0))
+                    array = enhanced.numpy().squeeze()
                 yield array
diff --git a/arguments_classes/vad_arguments.py b/arguments_classes/vad_arguments.py
index 450229c29114ab787b40c39e8f250840a1ec4a7d..41f4b6d6a252a22d7acf481fc26de0bfbf989a91 100644
--- a/arguments_classes/vad_arguments.py
+++ b/arguments_classes/vad_arguments.py
@@ -39,3 +39,9 @@ class VADHandlerArguments:
             "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
         },
     )
+    audio_enhancement: bool = field(
+        default=True,
+        metadata={
+            "help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is True."
+        },
+    )
diff --git a/requirements.txt b/requirements.txt
index 78a37b65ef6ed088a02c52180fbc3a598ce19726..fba30cd7f5e716797d29b3dd5890fd1a610d06a2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ torch==2.4.0
 sounddevice==0.5.0
 ChatTTS>=0.1.1
 funasr>=1.1.6
-modelscope>=1.17.1
\ No newline at end of file
+modelscope>=1.17.1
+deepfilternet>=0.5.6
diff --git a/requirements_mac.txt b/requirements_mac.txt
index 3bf9cb757d63ad88264f4aebe8bb6bb506eabe04..4a1c5cbb4a101ce611a2b81e4d52b73259782a0c 100644
--- a/requirements_mac.txt
+++ b/requirements_mac.txt
@@ -8,3 +8,4 @@ mlx-lm>=0.14.0
 ChatTTS>=0.1.1
 funasr>=1.1.6
 modelscope>=1.17.1
+deepfilternet>=0.5.6