From 156e39ebb23e2ed84a8b05d45e642280e3f7ac14 Mon Sep 17 00:00:00 2001
From: Michael Hansen <mike@rhasspy.org>
Date: Sat, 24 Aug 2024 15:21:03 -0500
Subject: [PATCH] Add minimum command seconds to VAD (#124447)

---
 .../components/assist_pipeline/vad.py         | 14 ++++++++++++-
 tests/components/assist_pipeline/test_vad.py  | 20 +++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/homeassistant/components/assist_pipeline/vad.py b/homeassistant/components/assist_pipeline/vad.py
index 8372dbc54c7..4782d14dee4 100644
--- a/homeassistant/components/assist_pipeline/vad.py
+++ b/homeassistant/components/assist_pipeline/vad.py
@@ -78,6 +78,9 @@ class VoiceCommandSegmenter:
     speech_seconds: float = 0.3
     """Seconds of speech before voice command has started."""
 
+    command_seconds: float = 1.0
+    """Minimum number of seconds for a voice command."""
+
     silence_seconds: float = 0.7
     """Seconds of silence after voice command has ended."""
 
@@ -96,6 +99,9 @@ class VoiceCommandSegmenter:
     _speech_seconds_left: float = 0.0
     """Seconds left before considering voice command as started."""
 
+    _command_seconds_left: float = 0.0
+    """Seconds left before voice command could stop."""
+
     _silence_seconds_left: float = 0.0
     """Seconds left before considering voice command as stopped."""
 
@@ -112,6 +118,7 @@ class VoiceCommandSegmenter:
     def reset(self) -> None:
         """Reset all counters and state."""
         self._speech_seconds_left = self.speech_seconds
+        self._command_seconds_left = self.command_seconds - self.speech_seconds
         self._silence_seconds_left = self.silence_seconds
         self._timeout_seconds_left = self.timeout_seconds
         self._reset_seconds_left = self.reset_seconds
@@ -142,6 +149,9 @@ class VoiceCommandSegmenter:
                 if self._speech_seconds_left <= 0:
                     # Inside voice command
                     self.in_command = True
+                    self._command_seconds_left = (
+                        self.command_seconds - self.speech_seconds
+                    )
                     self._silence_seconds_left = self.silence_seconds
                     _LOGGER.debug("Voice command started")
             else:
@@ -154,7 +164,8 @@ class VoiceCommandSegmenter:
             # Silence in command
             self._reset_seconds_left = self.reset_seconds
             self._silence_seconds_left -= chunk_seconds
-            if self._silence_seconds_left <= 0:
+            self._command_seconds_left -= chunk_seconds
+            if (self._silence_seconds_left <= 0) and (self._command_seconds_left <= 0):
                 # Command finished successfully
                 self.reset()
                 _LOGGER.debug("Voice command finished")
@@ -163,6 +174,7 @@ class VoiceCommandSegmenter:
             # Speech in command.
             # Reset silence counter if enough speech.
             self._reset_seconds_left -= chunk_seconds
+            self._command_seconds_left -= chunk_seconds
             if self._reset_seconds_left <= 0:
                 self._silence_seconds_left = self.silence_seconds
                 self._reset_seconds_left = self.reset_seconds
diff --git a/tests/components/assist_pipeline/test_vad.py b/tests/components/assist_pipeline/test_vad.py
index db039ab3140..fda26d2fb94 100644
--- a/tests/components/assist_pipeline/test_vad.py
+++ b/tests/components/assist_pipeline/test_vad.py
@@ -206,3 +206,23 @@ def test_timeout() -> None:
 
     assert not segmenter.process(_ONE_SECOND * 0.5, False)
     assert segmenter.timed_out
+
+
+def test_command_seconds() -> None:
+    """Test minimum number of seconds for voice command."""
+
+    segmenter = VoiceCommandSegmenter(
+        command_seconds=3, speech_seconds=1, silence_seconds=1, reset_seconds=1
+    )
+
+    assert segmenter.process(_ONE_SECOND, True)
+
+    # Silence counts towards total command length
+    assert segmenter.process(_ONE_SECOND * 0.5, False)
+
+    # Enough to finish command now
+    assert segmenter.process(_ONE_SECOND, True)
+    assert segmenter.process(_ONE_SECOND * 0.5, False)
+
+    # Silence to finish
+    assert not segmenter.process(_ONE_SECOND * 0.5, False)
-- 
GitLab