diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index f7af261c0afde3bd187eaaec48e4680f0badc8d0..b7186fc59d582cfb0be61cd70ebf23e1df508ff4 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -13,6 +13,7 @@ from time import perf_counter
 
 import numpy as np
 import torch
+import nltk
 from nltk.tokenize import sent_tokenize
 from rich.console import Console
 from transformers import (
@@ -36,6 +37,11 @@ from utils import (
     next_power_of_2
 )
 
+# Ensure that the necessary NLTK resources are available
+try:
+    nltk.data.find('tokenizers/punkt_tab')
+except LookupError:
+    nltk.download('punkt_tab')
 
 # caching allows ~50% compilation time reduction
 # see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
@@ -337,9 +343,9 @@ class VADHandler(BaseHandler):
             array = torch.cat(vad_output).cpu().numpy()
             duration_ms = len(array) / self.sample_rate * 1000
             if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
-                logger.debug(f"audio input of duration: {len(array) / self._sample_rate}s, skipping")
+                logger.debug(f"audio input of duration: {len(array) / self.sample_rate}s, skipping")
             else:
-                self._should_listen.clear()
+                self.should_listen.clear()
                 logger.debug("Stop listening")
                 yield array
 