diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 6c18fd96627850237c0916de53cce21bfb55c43a..1761323a60df06639c4fabbde13cda19036f7841 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -74,6 +74,7 @@ def async_register_commands( async_reg(hass, handle_validate_config) async_reg(hass, handle_subscribe_entities) async_reg(hass, handle_supported_brands) + async_reg(hass, handle_supported_features) def pong_message(iden: int) -> dict[str, Any]: @@ -723,3 +724,18 @@ async def handle_supported_brands( raise int_or_exc data[int_or_exc.domain] = int_or_exc.manifest["supported_brands"] connection.send_result(msg["id"], data) + + +@callback +@decorators.websocket_command( + { + vol.Required("type"): "supported_features", + vol.Required("features"): {str: int}, + } +) +def handle_supported_features( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: + """Handle setting supported features.""" + connection.supported_features = msg["features"] + connection.send_result(msg["id"]) diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 87c52288bccc5f1498ae79507ba63ec0d7ebd30f..c344e1c6a9fd4ec7fb8efacfd0fa36bcde7dee3e 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -42,6 +42,7 @@ class ActiveConnection: self.refresh_token_id = refresh_token.id self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.last_id = 0 + self.supported_features: dict[str, float] = {} current_connection.set(self) def context(self, msg: dict[str, Any]) -> Context: diff --git a/homeassistant/components/websocket_api/const.py b/homeassistant/components/websocket_api/const.py index 60a00126092143a3a8247e34e715ff97a055fd13..6135a821d53daf4f415357b7b3f973300e3f7d99 100644 --- a/homeassistant/components/websocket_api/const.py +++ b/homeassistant/components/websocket_api/const.py @@ -55,3 +55,5 @@ COMPRESSED_STATE_ATTRIBUTES = "a" COMPRESSED_STATE_CONTEXT = "c" COMPRESSED_STATE_LAST_CHANGED = "lc" COMPRESSED_STATE_LAST_UPDATED = "lu" + +FEATURE_COALESCE_MESSAGES = "coalesce_messages" diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index e8972a227c8853737872ea2840261786d7985fe7..7336fa1c0d22a31c1de782a65bd29f5136b63598 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -6,7 +6,7 @@ from collections.abc import Callable from contextlib import suppress import datetime as dt import logging -from typing import Any, Final +from typing import TYPE_CHECKING, Any, Final from aiohttp import WSMsgType, web import async_timeout @@ -16,11 +16,13 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import Event, HomeAssistant, callback from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.event import async_call_later +from homeassistant.helpers.json import json_loads from .auth import AuthPhase, auth_required_message from .const import ( CANCELLATION_ERRORS, DATA_CONNECTIONS, + FEATURE_COALESCE_MESSAGES, MAX_PENDING_MSG, PENDING_MSG_PEAK, PENDING_MSG_PEAK_TIME, @@ -31,6 +33,10 @@ from .const import ( from .error import Disconnect from .messages import message_to_json +if TYPE_CHECKING: + from .connection import ActiveConnection + + _WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection") @@ -67,26 +73,47 @@ class WebSocketHandler: self._writer_task: asyncio.Task | None = None self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)}) self._peak_checker_unsub: Callable[[], None] | None = None + self.connection: ActiveConnection | None = None async def _writer(self) -> None: """Write outgoing messages.""" # Exceptions if Socket disconnected or cancelled by connection handler - with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): - while not self.wsock.closed: - if (process := await self._to_write.get()) is None: - break - - if not isinstance(process, str): - message: str = process() - else: - message = process - self._logger.debug("Sending %s", message) - await self.wsock.send_str(message) - - # Clean up the peaker checker when we shut down the writer - if self._peak_checker_unsub is not None: - self._peak_checker_unsub() - self._peak_checker_unsub = None + to_write = self._to_write + logger = self._logger + wsock = self.wsock + try: + with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): + while not self.wsock.closed: + if (process := await to_write.get()) is None: + return + message = process if isinstance(process, str) else process() + + if ( + to_write.empty() + or not self.connection + or FEATURE_COALESCE_MESSAGES + not in self.connection.supported_features + ): + logger.debug("Sending %s", message) + await wsock.send_str(message) + continue + + messages: list[str] = [message] + while not to_write.empty(): + if (process := to_write.get_nowait()) is None: + return + messages.append( + process if isinstance(process, str) else process() + ) + + coalesced_messages = "[" + ",".join(messages) + "]" + self._logger.debug("Sending %s", coalesced_messages) + await self.wsock.send_str(coalesced_messages) + finally: + # Clean up the peaker checker when we shut down the writer + if self._peak_checker_unsub is not None: + self._peak_checker_unsub() + self._peak_checker_unsub = None @callback def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None: @@ -194,13 +221,13 @@ class WebSocketHandler: raise Disconnect try: - msg_data = msg.json() + msg_data = msg.json(loads=json_loads) except ValueError as err: disconnect_warn = "Received invalid JSON." raise Disconnect from err self._logger.debug("Received %s", msg_data) - connection = await auth.async_handle(msg_data) + self.connection = connection = await auth.async_handle(msg_data) self.hass.data[DATA_CONNECTIONS] = ( self.hass.data.get(DATA_CONNECTIONS, 0) + 1 ) @@ -218,13 +245,18 @@ class WebSocketHandler: break try: - msg_data = msg.json() + msg_data = msg.json(loads=json_loads) except ValueError: disconnect_warn = "Received invalid JSON." break self._logger.debug("Received %s", msg_data) - connection.async_handle(msg_data) + if not isinstance(msg_data, list): + connection.async_handle(msg_data) + continue + + for split_msg in msg_data: + connection.async_handle(split_msg) except asyncio.CancelledError: self._logger.info("Connection closed by client") @@ -257,6 +289,8 @@ class WebSocketHandler: if connection is not None: self.hass.data[DATA_CONNECTIONS] -= 1 + self.connection = None + async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED) return wsock diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index f1065061c7354df58815c4171162a7139e6848f6..fe748e2c47c79bdf0d32a0d1e5d88c769defb8d8 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -13,12 +13,13 @@ from homeassistant.components.websocket_api.auth import ( TYPE_AUTH_OK, TYPE_AUTH_REQUIRED, ) -from homeassistant.components.websocket_api.const import URL +from homeassistant.components.websocket_api.const import FEATURE_COALESCE_MESSAGES, URL from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATONS from homeassistant.core import Context, HomeAssistant, State, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import entity from homeassistant.helpers.dispatcher import async_dispatcher_send +from homeassistant.helpers.json import json_loads from homeassistant.loader import async_get_integration from homeassistant.setup import DATA_SETUP_TIME, async_setup_component @@ -1788,3 +1789,186 @@ async def test_supported_brands(hass, websocket_client): "hello": "World", }, } + + +async def test_message_coalescing(hass, websocket_client, hass_admin_user): + """Test enabling message coalescing.""" + await websocket_client.send_json( + { + "id": 1, + "type": "supported_features", + "features": {FEATURE_COALESCE_MESSAGES: 1}, + } + ) + hass.states.async_set("light.permitted", "on", {"color": "red"}) + + data = await websocket_client.receive_str() + msg = json_loads(data) + assert msg["id"] == 1 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + + await websocket_client.send_json({"id": 7, "type": "subscribe_entities"}) + + data = await websocket_client.receive_str() + msgs = json_loads(data) + msg = msgs.pop(0) + assert msg["id"] == 7 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + + msg = msgs.pop(0) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "a": { + "light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"} + } + } + + hass.states.async_set("light.permitted", "on", {"color": "yellow"}) + hass.states.async_set("light.permitted", "on", {"color": "green"}) + hass.states.async_set("light.permitted", "on", {"color": "blue"}) + + data = await websocket_client.receive_str() + msgs = json_loads(data) + + msg = msgs.pop(0) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": {"light.permitted": {"+": {"a": {"color": "yellow"}, "c": ANY, "lu": ANY}}} + } + + msg = msgs.pop(0) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": {"light.permitted": {"+": {"a": {"color": "green"}, "c": ANY, "lu": ANY}}} + } + + msg = msgs.pop(0) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}} + } + + hass.states.async_set("light.permitted", "on", {"color": "yellow"}) + hass.states.async_set("light.permitted", "on", {"color": "green"}) + hass.states.async_set("light.permitted", "on", {"color": "blue"}) + await websocket_client.close() + await hass.async_block_till_done() + + +async def test_message_coalescing_not_supported_by_websocket_client( + hass, websocket_client, hass_admin_user +): + """Test enabling message coalescing not supported by websocket client.""" + await websocket_client.send_json({"id": 7, "type": "subscribe_entities"}) + + data = await websocket_client.receive_str() + msg = json_loads(data) + assert msg["id"] == 7 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + + hass.states.async_set("light.permitted", "on", {"color": "red"}) + hass.states.async_set("light.permitted", "on", {"color": "blue"}) + + data = await websocket_client.receive_str() + msg = json_loads(data) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == {"a": {}} + + data = await websocket_client.receive_str() + msg = json_loads(data) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "a": { + "light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"} + } + } + + data = await websocket_client.receive_str() + msg = json_loads(data) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}} + } + await websocket_client.close() + await hass.async_block_till_done() + + +async def test_client_message_coalescing(hass, websocket_client, hass_admin_user): + """Test client message coalescing.""" + await websocket_client.send_json( + [ + { + "id": 1, + "type": "supported_features", + "features": {FEATURE_COALESCE_MESSAGES: 1}, + }, + {"id": 7, "type": "subscribe_entities"}, + ] + ) + hass.states.async_set("light.permitted", "on", {"color": "red"}) + + data = await websocket_client.receive_str() + msgs = json_loads(data) + + msg = msgs.pop(0) + assert msg["id"] == 1 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + + msg = msgs.pop(0) + assert msg["id"] == 7 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + + msg = msgs.pop(0) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "a": { + "light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"} + } + } + + hass.states.async_set("light.permitted", "on", {"color": "yellow"}) + hass.states.async_set("light.permitted", "on", {"color": "green"}) + hass.states.async_set("light.permitted", "on", {"color": "blue"}) + + data = await websocket_client.receive_str() + msgs = json_loads(data) + + msg = msgs.pop(0) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": {"light.permitted": {"+": {"a": {"color": "yellow"}, "c": ANY, "lu": ANY}}} + } + + msg = msgs.pop(0) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": {"light.permitted": {"+": {"a": {"color": "green"}, "c": ANY, "lu": ANY}}} + } + + msg = msgs.pop(0) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"] == { + "c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}} + } + + hass.states.async_set("light.permitted", "on", {"color": "yellow"}) + hass.states.async_set("light.permitted", "on", {"color": "green"}) + hass.states.async_set("light.permitted", "on", {"color": "blue"}) + await websocket_client.close() + await hass.async_block_till_done() diff --git a/tests/conftest.py b/tests/conftest.py index 0c0a654059b4c60564647c4f573cdbf7ea92a131..889568127cb6d35a9ce2f816b96d10820ece6f7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,15 +2,25 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable, Generator from contextlib import asynccontextmanager import functools +from json import JSONDecoder, loads import logging import ssl import threading +from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch -from aiohttp.test_utils import make_mocked_request +from aiohttp import client +from aiohttp.pytest_plugin import AiohttpClient +from aiohttp.test_utils import ( + BaseTestServer, + TestClient, + TestServer, + make_mocked_request, +) +from aiohttp.web import Application import freezegun import multidict import pytest @@ -57,6 +67,7 @@ from tests.components.recorder.common import ( # noqa: E402, isort:skip async_recorder_block_till_done, ) + _LOGGER = logging.getLogger(__name__) logging.basicConfig(level=logging.DEBUG) @@ -203,6 +214,97 @@ def load_registries(): return True +class CoalescingResponse(client.ClientWebSocketResponse): + """ClientWebSocketResponse client that mimics the websocket js code.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Init the ClientWebSocketResponse.""" + super().__init__(*args, **kwargs) + self._recv_buffer: list[Any] = [] + + async def receive_json( + self, + *, + loads: JSONDecoder = loads, + timeout: float | None = None, + ) -> Any: + """receive_json or from buffer.""" + if self._recv_buffer: + return self._recv_buffer.pop(0) + data = await self.receive_str(timeout=timeout) + decoded = loads(data) + if isinstance(decoded, list): + self._recv_buffer = decoded + return self._recv_buffer.pop(0) + return decoded + + +class CoalescingClient(TestClient): + """Client that mimics the websocket js code.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Init TestClient.""" + super().__init__(*args, ws_response_class=CoalescingResponse, **kwargs) + + +@pytest.fixture +def aiohttp_client_cls(): + """Override the test class for aiohttp.""" + return CoalescingClient + + +@pytest.fixture +def aiohttp_client( + loop: asyncio.AbstractEventLoop, +) -> Generator[AiohttpClient, None, None]: + """Override the default aiohttp_client since 3.x does not support aiohttp_client_cls. + + Remove this when upgrading to 4.x as aiohttp_client_cls + will do the same thing + + aiohttp_client(app, **kwargs) + aiohttp_client(server, **kwargs) + aiohttp_client(raw_server, **kwargs) + """ + clients = [] + + async def go( + __param: Application | BaseTestServer, + *args: Any, + server_kwargs: dict[str, Any] | None = None, + **kwargs: Any, + ) -> TestClient: + + if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type] + __param, (Application, BaseTestServer) + ): + __param = __param(loop, *args, **kwargs) + kwargs = {} + else: + assert not args, "args should be empty" + + if isinstance(__param, Application): + server_kwargs = server_kwargs or {} + server = TestServer(__param, loop=loop, **server_kwargs) + client = CoalescingClient(server, loop=loop, **kwargs) + elif isinstance(__param, BaseTestServer): + client = TestClient(__param, loop=loop, **kwargs) + else: + raise ValueError("Unknown argument type: %r" % type(__param)) + + await client.start_server() + clients.append(client) + return client + + yield go + + async def finalize() -> None: + while clients: + await clients.pop().close() + + loop.run_until_complete(finalize()) + + @pytest.fixture def hass(loop, load_registries, hass_storage, request): """Fixture to provide a test instance of Home Assistant."""