diff --git a/homeassistant/components/websocket_api/auth.py b/homeassistant/components/websocket_api/auth.py index 0a681692c3d4e6a209ab925560a335f488788a8f..176b561f58320dc7b0060334d681162d3bfc88e6 100644 --- a/homeassistant/components/websocket_api/auth.py +++ b/homeassistant/components/websocket_api/auth.py @@ -1,14 +1,13 @@ """Handle the auth of a connection.""" from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any, Final from aiohttp.web import Request import voluptuous as vol from voluptuous.humanize import humanize_error -from homeassistant.auth.models import RefreshToken, User from homeassistant.components.http.ban import process_success_login, process_wrong_login from homeassistant.const import __version__ from homeassistant.core import CALLBACK_TYPE, HomeAssistant @@ -41,9 +40,9 @@ AUTH_REQUIRED_MESSAGE = json_bytes( ) -def auth_invalid_message(message: str) -> dict[str, str]: +def auth_invalid_message(message: str) -> bytes: """Return an auth_invalid message.""" - return {"type": TYPE_AUTH_INVALID, "message": message} + return json_bytes({"type": TYPE_AUTH_INVALID, "message": message}) class AuthPhase: @@ -56,13 +55,17 @@ class AuthPhase: send_message: Callable[[bytes | str | dict[str, Any]], None], cancel_ws: CALLBACK_TYPE, request: Request, + send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]], ) -> None: - """Initialize the authentiated connection.""" + """Initialize the authenticated connection.""" self._hass = hass + # send_message will send a message to the client via the queue. self._send_message = send_message self._cancel_ws = cancel_ws self._logger = logger self._request = request + # send_bytes_text will directly send a message to the client. + self._send_bytes_text = send_bytes_text async def async_handle(self, msg: JsonValueType) -> ActiveConnection: """Handle authentication.""" @@ -73,7 +76,7 @@ class AuthPhase: f"Auth message incorrectly formatted: {humanize_error(msg, err)}" ) self._logger.warning(error_msg) - self._send_message(auth_invalid_message(error_msg)) + await self._send_bytes_text(auth_invalid_message(error_msg)) raise Disconnect from err if (access_token := valid_msg.get("access_token")) and ( @@ -81,26 +84,25 @@ class AuthPhase: access_token ) ): - conn = await self._async_finish_auth(refresh_token.user, refresh_token) + conn = ActiveConnection( + self._logger, + self._hass, + self._send_message, + refresh_token.user, + refresh_token, + ) conn.subscriptions[ "auth" ] = self._hass.auth.async_register_revoke_token_callback( refresh_token.id, self._cancel_ws ) - + await self._send_bytes_text(AUTH_OK_MESSAGE) + self._logger.debug("Auth OK") + process_success_login(self._request) return conn - self._send_message(auth_invalid_message("Invalid access token or password")) + await self._send_bytes_text( + auth_invalid_message("Invalid access token or password") + ) await process_wrong_login(self._request) raise Disconnect - - async def _async_finish_auth( - self, user: User, refresh_token: RefreshToken - ) -> ActiveConnection: - """Create an active connection.""" - self._logger.debug("Auth OK") - process_success_login(self._request) - self._send_message(AUTH_OK_MESSAGE) - return ActiveConnection( - self._logger, self._hass, self._send_message, user, refresh_token - ) diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index d966e4e26ef505f966e3b4aaa7c0c71d8cd47a64..416573d493cd8b1082363bbe32279e0d1d6423aa 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio from collections import deque -from collections.abc import Callable +from collections.abc import Callable, Coroutine import datetime as dt from functools import partial import logging @@ -116,16 +116,14 @@ class WebSocketHandler: return describe_request(request) return "finished connection" - async def _writer(self) -> None: + async def _writer( + self, send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]] + ) -> None: """Write outgoing messages.""" # Variables are set locally to avoid lookups in the loop message_queue = self._message_queue logger = self._logger wsock = self._wsock - writer = wsock._writer # pylint: disable=protected-access - if TYPE_CHECKING: - assert writer is not None - send_str = partial(writer.send, binary=False) loop = self._hass.loop debug = logger.debug is_enabled_for = logger.isEnabledFor @@ -152,7 +150,7 @@ class WebSocketHandler: ): if debug_enabled: debug("%s: Sending %s", self.description, message) - await send_str(message) + await send_bytes_text(message) continue messages: list[bytes] = [message] @@ -166,7 +164,7 @@ class WebSocketHandler: coalesced_messages = b"".join((b"[", b",".join(messages), b"]")) if debug_enabled: debug("%s: Sending %s", self.description, coalesced_messages) - await send_str(coalesced_messages) + await send_bytes_text(coalesced_messages) except asyncio.CancelledError: debug("%s: Writer cancelled", self.description) raise @@ -186,7 +184,7 @@ class WebSocketHandler: @callback def _send_message(self, message: str | bytes | dict[str, Any]) -> None: - """Send a message to the client. + """Queue sending a message to the client. Closes connection if the client is not reading the messages. @@ -295,21 +293,23 @@ class WebSocketHandler: EVENT_HOMEASSISTANT_STOP, self._async_handle_hass_stop ) - # As the webserver is now started before the start - # event we do not want to block for websocket responses - self._writer_task = asyncio.create_task(self._writer()) + writer = wsock._writer # pylint: disable=protected-access + if TYPE_CHECKING: + assert writer is not None - auth = AuthPhase(logger, hass, self._send_message, self._cancel, request) + send_bytes_text = partial(writer.send, binary=False) + auth = AuthPhase( + logger, hass, self._send_message, self._cancel, request, send_bytes_text + ) connection = None disconnect_warn = None try: - self._send_message(AUTH_REQUIRED_MESSAGE) + await send_bytes_text(AUTH_REQUIRED_MESSAGE) # Auth Phase try: - async with asyncio.timeout(10): - msg = await wsock.receive() + msg = await wsock.receive(10) except asyncio.TimeoutError as err: disconnect_warn = "Did not receive auth message within 10 seconds" raise Disconnect from err @@ -330,7 +330,13 @@ class WebSocketHandler: if is_enabled_for(logging_debug): debug("%s: Received %s", self.description, auth_msg_data) connection = await auth.async_handle(auth_msg_data) + # As the webserver is now started before the start + # event we do not want to block for websocket responses + # + # We only start the writer queue after the auth phase is completed + # since there is no need to queue messages before the auth phase self._connection = connection + self._writer_task = asyncio.create_task(self._writer(send_bytes_text)) hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1 async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED) @@ -370,7 +376,7 @@ class WebSocketHandler: # added a way to set the limit, but there is no way to actually # reach the code to set the limit, so we have to set it directly. # - wsock._writer._limit = 2**20 # type: ignore[union-attr] # pylint: disable=protected-access + writer._limit = 2**20 # pylint: disable=protected-access async_handle_str = connection.async_handle async_handle_binary = connection.async_handle_binary @@ -441,7 +447,8 @@ class WebSocketHandler: # so we have another finally block to make sure we close the websocket # if the writer gets canceled. try: - await self._writer_task + if self._writer_task: + await self._writer_task finally: try: # Make sure all error messages are written before closing