Skip to content
Snippets Groups Projects
Unverified Commit 68384bba authored by Jesse Hills, committed by GitHub
Browse files

Send/receive Voice Assistant audio via ESPHome native API (#114800)


* Protobuf audio test

* Remove extraneous code

* Rework voice assistant pipeline

* Move variables

* Fix reading flags

* Don't directly put to queue

* Bump aioesphomeapi to 24.0.0

* Update tests

- Add more tests for API pipeline
- Convert some udp tests to use api pipeline
- Update fixtures for new device info flags

* Fix bad merge

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
parent cad4c3c0
No related branches found
No related tags found
No related merge requests found
......@@ -33,7 +33,9 @@ async def async_setup_entry(
entry_data = DomainData.get(hass).get_entry_data(entry)
assert entry_data.device_info is not None
if entry_data.device_info.voice_assistant_version:
if entry_data.device_info.voice_assistant_feature_flags_compat(
entry_data.api_version
):
async_add_entities([EsphomeAssistInProgressBinarySensor(entry_data)])
......
......@@ -257,7 +257,9 @@ class RuntimeEntryData:
if async_get_dashboard(hass):
needed_platforms.add(Platform.UPDATE)
if self.device_info and self.device_info.voice_assistant_version:
if self.device_info and self.device_info.voice_assistant_feature_flags_compat(
self.api_version
):
needed_platforms.add(Platform.BINARY_SENSOR)
needed_platforms.add(Platform.SELECT)
......
......@@ -21,6 +21,7 @@ from aioesphomeapi import (
UserService,
UserServiceArgType,
VoiceAssistantAudioSettings,
VoiceAssistantFeature,
)
from awesomeversion import AwesomeVersion
import voluptuous as vol
......@@ -72,7 +73,11 @@ from .domain_data import DomainData
# Import config flow so that it's added to the registry
from .entry_data import RuntimeEntryData
from .voice_assistant import VoiceAssistantUDPServer
from .voice_assistant import (
VoiceAssistantAPIPipeline,
VoiceAssistantPipeline,
VoiceAssistantUDPPipeline,
)
_LOGGER = logging.getLogger(__name__)
......@@ -143,7 +148,7 @@ class ESPHomeManager:
"cli",
"device_id",
"domain_data",
"voice_assistant_udp_server",
"voice_assistant_pipeline",
"reconnect_logic",
"zeroconf_instance",
"entry_data",
......@@ -168,7 +173,7 @@ class ESPHomeManager:
self.cli = cli
self.device_id: str | None = None
self.domain_data = domain_data
self.voice_assistant_udp_server: VoiceAssistantUDPServer | None = None
self.voice_assistant_pipeline: VoiceAssistantPipeline | None = None
self.reconnect_logic: ReconnectLogic | None = None
self.zeroconf_instance = zeroconf_instance
self.entry_data = entry_data
......@@ -327,9 +332,10 @@ class ESPHomeManager:
def _handle_pipeline_finished(self) -> None:
self.entry_data.async_set_assist_pipeline_state(False)
if self.voice_assistant_udp_server is not None:
self.voice_assistant_udp_server.close()
self.voice_assistant_udp_server = None
if self.voice_assistant_pipeline is not None:
if isinstance(self.voice_assistant_pipeline, VoiceAssistantUDPPipeline):
self.voice_assistant_pipeline.close()
self.voice_assistant_pipeline = None
async def _handle_pipeline_start(
self,
......@@ -339,38 +345,60 @@ class ESPHomeManager:
wake_word_phrase: str | None,
) -> int | None:
"""Start a voice assistant pipeline."""
if self.voice_assistant_udp_server is not None:
if self.voice_assistant_pipeline is not None:
_LOGGER.warning("Voice assistant UDP server was not stopped")
self.voice_assistant_udp_server.stop()
self.voice_assistant_udp_server = None
self.voice_assistant_pipeline.stop()
self.voice_assistant_pipeline = None
hass = self.hass
self.voice_assistant_udp_server = VoiceAssistantUDPServer(
hass,
self.entry_data,
self.cli.send_voice_assistant_event,
self._handle_pipeline_finished,
)
port = await self.voice_assistant_udp_server.start_server()
assert self.entry_data.device_info is not None
if (
self.entry_data.device_info.voice_assistant_feature_flags_compat(
self.entry_data.api_version
)
& VoiceAssistantFeature.API_AUDIO
):
self.voice_assistant_pipeline = VoiceAssistantAPIPipeline(
hass,
self.entry_data,
self.cli.send_voice_assistant_event,
self._handle_pipeline_finished,
self.cli,
)
port = 0
else:
self.voice_assistant_pipeline = VoiceAssistantUDPPipeline(
hass,
self.entry_data,
self.cli.send_voice_assistant_event,
self._handle_pipeline_finished,
)
port = await self.voice_assistant_pipeline.start_server()
assert self.device_id is not None, "Device ID must be set"
hass.async_create_background_task(
self.voice_assistant_udp_server.run_pipeline(
self.voice_assistant_pipeline.run_pipeline(
device_id=self.device_id,
conversation_id=conversation_id or None,
flags=flags,
audio_settings=audio_settings,
wake_word_phrase=wake_word_phrase,
),
"esphome.voice_assistant_udp_server.run_pipeline",
"esphome.voice_assistant_pipeline.run_pipeline",
)
return port
async def _handle_pipeline_stop(self) -> None:
"""Stop a voice assistant pipeline."""
if self.voice_assistant_udp_server is not None:
self.voice_assistant_udp_server.stop()
if self.voice_assistant_pipeline is not None:
self.voice_assistant_pipeline.stop()
async def _handle_audio(self, data: bytes) -> None:
    """Forward audio bytes from the device to the active pipeline.

    Does nothing when no pipeline is currently active.
    """
    pipeline = self.voice_assistant_pipeline
    if pipeline is not None:
        # receive_audio_bytes is defined on the API pipeline, so narrow to it.
        assert isinstance(pipeline, VoiceAssistantAPIPipeline)
        pipeline.receive_audio_bytes(data)
async def on_connect(self) -> None:
"""Subscribe to states and list entities on successful API login."""
......@@ -472,13 +500,23 @@ class ESPHomeManager:
)
)
if device_info.voice_assistant_version:
entry_data.disconnect_callbacks.add(
cli.subscribe_voice_assistant(
self._handle_pipeline_start,
self._handle_pipeline_stop,
flags = device_info.voice_assistant_feature_flags_compat(api_version)
if flags:
if flags & VoiceAssistantFeature.API_AUDIO:
entry_data.disconnect_callbacks.add(
cli.subscribe_voice_assistant(
handle_start=self._handle_pipeline_start,
handle_stop=self._handle_pipeline_stop,
handle_audio=self._handle_audio,
)
)
else:
entry_data.disconnect_callbacks.add(
cli.subscribe_voice_assistant(
handle_start=self._handle_pipeline_start,
handle_stop=self._handle_pipeline_stop,
)
)
)
cli.subscribe_states(entry_data.async_update_state)
cli.subscribe_service_calls(self.async_on_service_call)
......
......@@ -15,7 +15,7 @@
"iot_class": "local_push",
"loggers": ["aioesphomeapi", "noiseprotocol", "bleak_esphome"],
"requirements": [
"aioesphomeapi==23.2.0",
"aioesphomeapi==24.0.0",
"esphome-dashboard-api==1.2.3",
"bleak-esphome==1.0.0"
],
......
......@@ -42,7 +42,9 @@ async def async_setup_entry(
entry_data = DomainData.get(hass).get_entry_data(entry)
assert entry_data.device_info is not None
if entry_data.device_info.voice_assistant_version:
if entry_data.device_info.voice_assistant_feature_flags_compat(
entry_data.api_version
):
async_add_entities(
[
EsphomeAssistPipelineSelect(hass, entry_data),
......
......@@ -11,9 +11,11 @@ from typing import cast
import wave
from aioesphomeapi import (
APIClient,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
VoiceAssistantFeature,
)
from homeassistant.components import stt, tts
......@@ -64,13 +66,11 @@ _VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
)
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"""Receive UDP packets and forward them to the voice assistant."""
class VoiceAssistantPipeline:
"""Base abstract pipeline class."""
started = False
stop_requested = False
transport: asyncio.DatagramTransport | None = None
remote_addr: tuple[str, int] | None = None
def __init__(
self,
......@@ -79,12 +79,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
handle_finished: Callable[[], None],
) -> None:
"""Initialize UDP receiver."""
"""Initialize the pipeline."""
self.context = Context()
self.hass = hass
assert entry_data.device_info is not None
self.entry_data = entry_data
assert entry_data.device_info is not None
self.device_info = entry_data.device_info
self.queue: asyncio.Queue[bytes] = asyncio.Queue()
......@@ -95,69 +94,9 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
@property
def is_running(self) -> bool:
"""True if the UDP server is started and hasn't been asked to stop."""
"""True if the pipeline is started and hasn't been asked to stop."""
return self.started and (not self.stop_requested)
async def start_server(self) -> int:
"""Start accepting connections."""
def accept_connection() -> VoiceAssistantUDPServer:
"""Accept connection."""
if self.started:
raise RuntimeError("Can only start once")
if self.stop_requested:
raise RuntimeError("No longer accepting connections")
self.started = True
return self
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(("", UDP_PORT))
await asyncio.get_running_loop().create_datagram_endpoint(
accept_connection, sock=sock
)
return cast(int, sock.getsockname()[1])
@callback
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Store transport for later use."""
self.transport = cast(asyncio.DatagramTransport, transport)
@callback
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""Handle incoming UDP packet."""
if not self.is_running:
return
if self.remote_addr is None:
self.remote_addr = addr
self.queue.put_nowait(data)
def error_received(self, exc: Exception) -> None:
"""Handle when a send or receive operation raises an OSError.
(Other than BlockingIOError or InterruptedError.)
"""
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
self.handle_finished()
@callback
def stop(self) -> None:
"""Stop the receiver."""
self.queue.put_nowait(b"")
self.close()
def close(self) -> None:
"""Close the receiver."""
self.started = False
self.stop_requested = True
if self.transport is not None:
self.transport.close()
async def _iterate_packets(self) -> AsyncIterable[bytes]:
"""Iterate over incoming packets."""
while data := await self.queue.get():
......@@ -198,7 +137,12 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
url = async_process_play_media_url(self.hass, path)
data_to_send = {"url": url}
if self.device_info.voice_assistant_version >= 2:
if (
self.device_info.voice_assistant_feature_flags_compat(
self.entry_data.api_version
)
& VoiceAssistantFeature.SPEAKER
):
media_id = tts_output["media_id"]
self._tts_task = self.hass.async_create_background_task(
self._send_tts(media_id), "esphome_voice_assistant_tts"
......@@ -243,9 +187,15 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
if audio_settings is None or audio_settings.volume_multiplier == 0:
audio_settings = VoiceAssistantAudioSettings()
tts_audio_output = (
"wav" if self.device_info.voice_assistant_version >= 2 else "mp3"
)
if (
self.device_info.voice_assistant_feature_flags_compat(
self.entry_data.api_version
)
& VoiceAssistantFeature.SPEAKER
):
tts_audio_output = "wav"
else:
tts_audio_output = "mp3"
_LOGGER.debug("Starting pipeline")
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
......@@ -315,7 +265,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
try:
if (not self.is_running) or (self.transport is None):
if not self.is_running:
return
extension, data = await tts.async_get_media_source_audio(
......@@ -358,16 +308,133 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
samples_in_chunk = len(chunk) // bytes_per_sample
samples_left -= samples_in_chunk
self.transport.sendto(chunk, self.remote_addr)
self.send_audio_bytes(chunk)
await asyncio.sleep(
samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9
)
sample_offset += samples_in_chunk
finally:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
)
self._tts_task = None
self._tts_done.set()
def send_audio_bytes(self, data: bytes) -> None:
    """Deliver audio bytes to the device.

    The base pipeline has no transport of its own; subclasses override this.

    Raises:
        NotImplementedError: always, when called on the base implementation.
    """
    raise NotImplementedError
def stop(self) -> None:
    """Stop the pipeline by queueing an empty chunk.

    An empty bytes object is falsy, which ends the packet iteration loop
    that reads from the same queue.
    """
    end_of_stream = b""
    self.queue.put_nowait(end_of_stream)
class VoiceAssistantUDPPipeline(asyncio.DatagramProtocol, VoiceAssistantPipeline):
    """Receive UDP packets and forward them to the voice assistant."""

    # Datagram transport; set by connection_made() once the endpoint exists.
    transport: asyncio.DatagramTransport | None = None
    # Address of the first peer that sent a packet; outgoing audio is sent there.
    remote_addr: tuple[str, int] | None = None

    async def start_server(self) -> int:
        """Bind a UDP socket, attach this protocol to it and return the port.

        Raises:
            RuntimeError: if the pipeline was already started or was asked
                to stop (raised by the protocol factory below).
        """

        def accept_connection() -> VoiceAssistantUDPPipeline:
            """Accept connection."""
            if self.started:
                raise RuntimeError("Can only start once")
            if self.stop_requested:
                raise RuntimeError("No longer accepting connections")

            self.started = True
            return self

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.setblocking(False)
            sock.bind(("", UDP_PORT))

            await asyncio.get_running_loop().create_datagram_endpoint(
                accept_connection, sock=sock
            )
        except Exception:
            # Binding or endpoint creation failed: release the socket
            # instead of leaving it open until garbage collection.
            sock.close()
            raise

        return cast(int, sock.getsockname()[1])

    @callback
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Store transport for later use."""
        self.transport = cast(asyncio.DatagramTransport, transport)

    @callback
    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
        """Handle incoming UDP packet."""
        if not self.is_running:
            return
        # Remember the first sender so audio can be sent back to it later.
        if self.remote_addr is None:
            self.remote_addr = addr
        self.queue.put_nowait(data)

    def error_received(self, exc: Exception) -> None:
        """Handle when a send or receive operation raises an OSError.

        (Other than BlockingIOError or InterruptedError.)
        """
        _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
        self.handle_finished()

    @callback
    def stop(self) -> None:
        """Stop the receiver."""
        super().stop()
        self.close()

    def close(self) -> None:
        """Close the receiver."""
        self.started = False
        self.stop_requested = True
        if self.transport is not None:
            self.transport.close()

    def send_audio_bytes(self, data: bytes) -> None:
        """Send bytes to the device via UDP.

        Logs an error and drops the data when there is no transport or no
        peer address yet.
        """
        if self.transport is None:
            _LOGGER.error("No transport to send audio to")
            return
        if self.remote_addr is None:
            # The socket is bound but not connected, so sendto() needs an
            # explicit destination; none is known until a packet arrives.
            _LOGGER.error("No remote address to send audio to")
            return
        self.transport.sendto(data, self.remote_addr)
class VoiceAssistantAPIPipeline(VoiceAssistantPipeline):
    """Voice assistant pipeline that exchanges audio through the API client."""

    def __init__(
        self,
        hass: HomeAssistant,
        entry_data: RuntimeEntryData,
        handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
        handle_finished: Callable[[], None],
        api_client: APIClient,
    ) -> None:
        """Set up the base pipeline and keep a reference to the API client.

        The pipeline is flagged as started as soon as it is constructed.
        """
        super().__init__(hass, entry_data, handle_event, handle_finished)
        self.started = True
        self.api_client = api_client

    def send_audio_bytes(self, data: bytes) -> None:
        """Hand audio bytes to the API client for delivery to the device."""
        self.api_client.send_voice_assistant_audio(data)

    @callback
    def receive_audio_bytes(self, data: bytes) -> None:
        """Queue audio bytes from the device while the pipeline is running."""
        if self.is_running:
            self.queue.put_nowait(data)

    @callback
    def stop(self) -> None:
        """Run the base stop, then mark the pipeline as no longer running."""
        super().stop()
        self.stop_requested = True
        self.started = False
......@@ -242,7 +242,7 @@ aioelectricitymaps==0.4.0
aioemonitor==1.0.5
# homeassistant.components.esphome
aioesphomeapi==23.2.0
aioesphomeapi==24.0.0
# homeassistant.components.flo
aioflo==2021.11.0
......
......@@ -221,7 +221,7 @@ aioelectricitymaps==0.4.0
aioemonitor==1.0.5
# homeassistant.components.esphome
aioesphomeapi==23.2.0
aioesphomeapi==24.0.0
# homeassistant.components.flo
aioflo==2021.11.0
......
......@@ -18,6 +18,7 @@ from aioesphomeapi import (
HomeassistantServiceCall,
ReconnectLogic,
UserService,
VoiceAssistantFeature,
)
import pytest
from zeroconf import Zeroconf
......@@ -354,10 +355,16 @@ async def mock_voice_assistant_entry(
):
"""Set up an ESPHome entry with voice assistant."""
async def _mock_voice_assistant_entry(version: int) -> MockConfigEntry:
async def _mock_voice_assistant_entry(
voice_assistant_feature_flags: VoiceAssistantFeature,
) -> MockConfigEntry:
return (
await _mock_generic_device_entry(
hass, mock_client, {"voice_assistant_version": version}, ([], []), []
hass,
mock_client,
{"voice_assistant_feature_flags": voice_assistant_feature_flags},
([], []),
[],
)
).entry
......@@ -367,13 +374,28 @@ async def mock_voice_assistant_entry(
@pytest.fixture
async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfigEntry:
"""Set up an ESPHome entry with voice assistant."""
return await mock_voice_assistant_entry(version=1)
return await mock_voice_assistant_entry(
voice_assistant_feature_flags=VoiceAssistantFeature.VOICE_ASSISTANT
)
@pytest.fixture
async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry:
"""Set up an ESPHome entry with voice assistant."""
return await mock_voice_assistant_entry(version=2)
return await mock_voice_assistant_entry(
voice_assistant_feature_flags=VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.SPEAKER
)
@pytest.fixture
async def mock_voice_assistant_api_entry(mock_voice_assistant_entry) -> MockConfigEntry:
    """Create an ESPHome voice assistant entry that also has the API audio flag."""
    api_audio_flags = (
        VoiceAssistantFeature.VOICE_ASSISTANT
        | VoiceAssistantFeature.SPEAKER
        | VoiceAssistantFeature.API_AUDIO
    )
    return await mock_voice_assistant_entry(
        voice_assistant_feature_flags=api_audio_flags
    )
@pytest.fixture
......
......@@ -94,7 +94,8 @@ async def test_diagnostics_with_bluetooth(
"project_version": "",
"suggested_area": "",
"uses_password": False,
"voice_assistant_version": 0,
"legacy_voice_assistant_version": 0,
"voice_assistant_feature_flags": 0,
"webserver_port": 0,
},
"services": [],
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment