Skip to content
Snippets Groups Projects
Unverified Commit f6a03625 authored by J. Nick Koston's avatar J. Nick Koston Committed by GitHub
Browse files

Implement websocket message coalescing (#77238)

parent 2161b6f0
No related branches found
No related tags found
No related merge requests found
......@@ -74,6 +74,7 @@ def async_register_commands(
async_reg(hass, handle_validate_config)
async_reg(hass, handle_subscribe_entities)
async_reg(hass, handle_supported_brands)
async_reg(hass, handle_supported_features)
def pong_message(iden: int) -> dict[str, Any]:
......@@ -723,3 +724,18 @@ async def handle_supported_brands(
raise int_or_exc
data[int_or_exc.domain] = int_or_exc.manifest["supported_brands"]
connection.send_result(msg["id"], data)
@callback
@decorators.websocket_command(
{
vol.Required("type"): "supported_features",
vol.Required("features"): {str: int},
}
)
def handle_supported_features(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle setting supported features."""
connection.supported_features = msg["features"]
connection.send_result(msg["id"])
......@@ -42,6 +42,7 @@ class ActiveConnection:
self.refresh_token_id = refresh_token.id
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0
self.supported_features: dict[str, float] = {}
current_connection.set(self)
def context(self, msg: dict[str, Any]) -> Context:
......
......@@ -55,3 +55,5 @@ COMPRESSED_STATE_ATTRIBUTES = "a"
COMPRESSED_STATE_CONTEXT = "c"
COMPRESSED_STATE_LAST_CHANGED = "lc"
COMPRESSED_STATE_LAST_UPDATED = "lu"
FEATURE_COALESCE_MESSAGES = "coalesce_messages"
......@@ -6,7 +6,7 @@ from collections.abc import Callable
from contextlib import suppress
import datetime as dt
import logging
from typing import Any, Final
from typing import TYPE_CHECKING, Any, Final
from aiohttp import WSMsgType, web
import async_timeout
......@@ -16,11 +16,13 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.json import json_loads
from .auth import AuthPhase, auth_required_message
from .const import (
CANCELLATION_ERRORS,
DATA_CONNECTIONS,
FEATURE_COALESCE_MESSAGES,
MAX_PENDING_MSG,
PENDING_MSG_PEAK,
PENDING_MSG_PEAK_TIME,
......@@ -31,6 +33,10 @@ from .const import (
from .error import Disconnect
from .messages import message_to_json
if TYPE_CHECKING:
from .connection import ActiveConnection
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
......@@ -67,26 +73,47 @@ class WebSocketHandler:
self._writer_task: asyncio.Task | None = None
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
self._peak_checker_unsub: Callable[[], None] | None = None
self.connection: ActiveConnection | None = None
async def _writer(self) -> None:
"""Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
if (process := await self._to_write.get()) is None:
break
if not isinstance(process, str):
message: str = process()
else:
message = process
self._logger.debug("Sending %s", message)
await self.wsock.send_str(message)
# Clean up the peaker checker when we shut down the writer
if self._peak_checker_unsub is not None:
self._peak_checker_unsub()
self._peak_checker_unsub = None
to_write = self._to_write
logger = self._logger
wsock = self.wsock
try:
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
if (process := await to_write.get()) is None:
return
message = process if isinstance(process, str) else process()
if (
to_write.empty()
or not self.connection
or FEATURE_COALESCE_MESSAGES
not in self.connection.supported_features
):
logger.debug("Sending %s", message)
await wsock.send_str(message)
continue
messages: list[str] = [message]
while not to_write.empty():
if (process := to_write.get_nowait()) is None:
return
messages.append(
process if isinstance(process, str) else process()
)
coalesced_messages = "[" + ",".join(messages) + "]"
self._logger.debug("Sending %s", coalesced_messages)
await self.wsock.send_str(coalesced_messages)
finally:
# Clean up the peaker checker when we shut down the writer
if self._peak_checker_unsub is not None:
self._peak_checker_unsub()
self._peak_checker_unsub = None
@callback
def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None:
......@@ -194,13 +221,13 @@ class WebSocketHandler:
raise Disconnect
try:
msg_data = msg.json()
msg_data = msg.json(loads=json_loads)
except ValueError as err:
disconnect_warn = "Received invalid JSON."
raise Disconnect from err
self._logger.debug("Received %s", msg_data)
connection = await auth.async_handle(msg_data)
self.connection = connection = await auth.async_handle(msg_data)
self.hass.data[DATA_CONNECTIONS] = (
self.hass.data.get(DATA_CONNECTIONS, 0) + 1
)
......@@ -218,13 +245,18 @@ class WebSocketHandler:
break
try:
msg_data = msg.json()
msg_data = msg.json(loads=json_loads)
except ValueError:
disconnect_warn = "Received invalid JSON."
break
self._logger.debug("Received %s", msg_data)
connection.async_handle(msg_data)
if not isinstance(msg_data, list):
connection.async_handle(msg_data)
continue
for split_msg in msg_data:
connection.async_handle(split_msg)
except asyncio.CancelledError:
self._logger.info("Connection closed by client")
......@@ -257,6 +289,8 @@ class WebSocketHandler:
if connection is not None:
self.hass.data[DATA_CONNECTIONS] -= 1
self.connection = None
async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED)
return wsock
......@@ -13,12 +13,13 @@ from homeassistant.components.websocket_api.auth import (
TYPE_AUTH_OK,
TYPE_AUTH_REQUIRED,
)
from homeassistant.components.websocket_api.const import URL
from homeassistant.components.websocket_api.const import FEATURE_COALESCE_MESSAGES, URL
from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATONS
from homeassistant.core import Context, HomeAssistant, State, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.json import json_loads
from homeassistant.loader import async_get_integration
from homeassistant.setup import DATA_SETUP_TIME, async_setup_component
......@@ -1788,3 +1789,186 @@ async def test_supported_brands(hass, websocket_client):
"hello": "World",
},
}
async def test_message_coalescing(hass, websocket_client, hass_admin_user):
"""Test enabling message coalescing."""
await websocket_client.send_json(
{
"id": 1,
"type": "supported_features",
"features": {FEATURE_COALESCE_MESSAGES: 1},
}
)
hass.states.async_set("light.permitted", "on", {"color": "red"})
data = await websocket_client.receive_str()
msg = json_loads(data)
assert msg["id"] == 1
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
await websocket_client.send_json({"id": 7, "type": "subscribe_entities"})
data = await websocket_client.receive_str()
msgs = json_loads(data)
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"a": {
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
}
}
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
hass.states.async_set("light.permitted", "on", {"color": "green"})
hass.states.async_set("light.permitted", "on", {"color": "blue"})
data = await websocket_client.receive_str()
msgs = json_loads(data)
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "yellow"}, "c": ANY, "lu": ANY}}}
}
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "green"}, "c": ANY, "lu": ANY}}}
}
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
}
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
hass.states.async_set("light.permitted", "on", {"color": "green"})
hass.states.async_set("light.permitted", "on", {"color": "blue"})
await websocket_client.close()
await hass.async_block_till_done()
async def test_message_coalescing_not_supported_by_websocket_client(
hass, websocket_client, hass_admin_user
):
"""Test enabling message coalescing not supported by websocket client."""
await websocket_client.send_json({"id": 7, "type": "subscribe_entities"})
data = await websocket_client.receive_str()
msg = json_loads(data)
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
hass.states.async_set("light.permitted", "on", {"color": "red"})
hass.states.async_set("light.permitted", "on", {"color": "blue"})
data = await websocket_client.receive_str()
msg = json_loads(data)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {"a": {}}
data = await websocket_client.receive_str()
msg = json_loads(data)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"a": {
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
}
}
data = await websocket_client.receive_str()
msg = json_loads(data)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
}
await websocket_client.close()
await hass.async_block_till_done()
async def test_client_message_coalescing(hass, websocket_client, hass_admin_user):
"""Test client message coalescing."""
await websocket_client.send_json(
[
{
"id": 1,
"type": "supported_features",
"features": {FEATURE_COALESCE_MESSAGES: 1},
},
{"id": 7, "type": "subscribe_entities"},
]
)
hass.states.async_set("light.permitted", "on", {"color": "red"})
data = await websocket_client.receive_str()
msgs = json_loads(data)
msg = msgs.pop(0)
assert msg["id"] == 1
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"a": {
"light.permitted": {"a": {"color": "red"}, "c": ANY, "lc": ANY, "s": "on"}
}
}
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
hass.states.async_set("light.permitted", "on", {"color": "green"})
hass.states.async_set("light.permitted", "on", {"color": "blue"})
data = await websocket_client.receive_str()
msgs = json_loads(data)
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "yellow"}, "c": ANY, "lu": ANY}}}
}
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "green"}, "c": ANY, "lu": ANY}}}
}
msg = msgs.pop(0)
assert msg["id"] == 7
assert msg["type"] == "event"
assert msg["event"] == {
"c": {"light.permitted": {"+": {"a": {"color": "blue"}, "c": ANY, "lu": ANY}}}
}
hass.states.async_set("light.permitted", "on", {"color": "yellow"})
hass.states.async_set("light.permitted", "on", {"color": "green"})
hass.states.async_set("light.permitted", "on", {"color": "blue"})
await websocket_client.close()
await hass.async_block_till_done()
......@@ -2,15 +2,25 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable, Generator
from contextlib import asynccontextmanager
import functools
from json import JSONDecoder, loads
import logging
import ssl
import threading
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from aiohttp.test_utils import make_mocked_request
from aiohttp import client
from aiohttp.pytest_plugin import AiohttpClient
from aiohttp.test_utils import (
BaseTestServer,
TestClient,
TestServer,
make_mocked_request,
)
from aiohttp.web import Application
import freezegun
import multidict
import pytest
......@@ -57,6 +67,7 @@ from tests.components.recorder.common import ( # noqa: E402, isort:skip
async_recorder_block_till_done,
)
_LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
......@@ -203,6 +214,97 @@ def load_registries():
return True
class CoalescingResponse(client.ClientWebSocketResponse):
"""ClientWebSocketResponse client that mimics the websocket js code."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Init the ClientWebSocketResponse."""
super().__init__(*args, **kwargs)
self._recv_buffer: list[Any] = []
async def receive_json(
self,
*,
loads: JSONDecoder = loads,
timeout: float | None = None,
) -> Any:
"""receive_json or from buffer."""
if self._recv_buffer:
return self._recv_buffer.pop(0)
data = await self.receive_str(timeout=timeout)
decoded = loads(data)
if isinstance(decoded, list):
self._recv_buffer = decoded
return self._recv_buffer.pop(0)
return decoded
class CoalescingClient(TestClient):
"""Client that mimics the websocket js code."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Init TestClient."""
super().__init__(*args, ws_response_class=CoalescingResponse, **kwargs)
@pytest.fixture
def aiohttp_client_cls():
"""Override the test class for aiohttp."""
return CoalescingClient
@pytest.fixture
def aiohttp_client(
loop: asyncio.AbstractEventLoop,
) -> Generator[AiohttpClient, None, None]:
"""Override the default aiohttp_client since 3.x does not support aiohttp_client_cls.
Remove this when upgrading to 4.x as aiohttp_client_cls
will do the same thing
aiohttp_client(app, **kwargs)
aiohttp_client(server, **kwargs)
aiohttp_client(raw_server, **kwargs)
"""
clients = []
async def go(
__param: Application | BaseTestServer,
*args: Any,
server_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> TestClient:
if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type]
__param, (Application, BaseTestServer)
):
__param = __param(loop, *args, **kwargs)
kwargs = {}
else:
assert not args, "args should be empty"
if isinstance(__param, Application):
server_kwargs = server_kwargs or {}
server = TestServer(__param, loop=loop, **server_kwargs)
client = CoalescingClient(server, loop=loop, **kwargs)
elif isinstance(__param, BaseTestServer):
client = TestClient(__param, loop=loop, **kwargs)
else:
raise ValueError("Unknown argument type: %r" % type(__param))
await client.start_server()
clients.append(client)
return client
yield go
async def finalize() -> None:
while clients:
await clients.pop().close()
loop.run_until_complete(finalize())
@pytest.fixture
def hass(loop, load_registries, hass_storage, request):
"""Fixture to provide a test instance of Home Assistant."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment