From dbb5645e63b4cebe591102657b2b4efcfd7567a2 Mon Sep 17 00:00:00 2001
From: "J. Nick Koston" <nick@koston.org>
Date: Sun, 21 Jan 2024 17:33:31 -1000
Subject: [PATCH] Significantly reduce websocket api connection auth phase
 latency (#108564)

* Significantly reduce websocket api connection auth phase latancy

Since the auth phase has exclusive control over the websocket
until ActiveConnection is created, we can bypass the queue and
send messages right away. This reduces the latancy and reconnect
time since we do not have to wait for the background processing
of the queue to send the auth ok message.

* only start the writer queue after auth is successful
---
 .../components/websocket_api/auth.py          | 42 +++++++++---------
 .../components/websocket_api/http.py          | 43 +++++++++++--------
 2 files changed, 47 insertions(+), 38 deletions(-)

diff --git a/homeassistant/components/websocket_api/auth.py b/homeassistant/components/websocket_api/auth.py
index 0a681692c3d..176b561f583 100644
--- a/homeassistant/components/websocket_api/auth.py
+++ b/homeassistant/components/websocket_api/auth.py
@@ -1,14 +1,13 @@
 """Handle the auth of a connection."""
 from __future__ import annotations
 
-from collections.abc import Callable
+from collections.abc import Callable, Coroutine
 from typing import TYPE_CHECKING, Any, Final
 
 from aiohttp.web import Request
 import voluptuous as vol
 from voluptuous.humanize import humanize_error
 
-from homeassistant.auth.models import RefreshToken, User
 from homeassistant.components.http.ban import process_success_login, process_wrong_login
 from homeassistant.const import __version__
 from homeassistant.core import CALLBACK_TYPE, HomeAssistant
@@ -41,9 +40,9 @@ AUTH_REQUIRED_MESSAGE = json_bytes(
 )
 
 
-def auth_invalid_message(message: str) -> dict[str, str]:
+def auth_invalid_message(message: str) -> bytes:
     """Return an auth_invalid message."""
-    return {"type": TYPE_AUTH_INVALID, "message": message}
+    return json_bytes({"type": TYPE_AUTH_INVALID, "message": message})
 
 
 class AuthPhase:
@@ -56,13 +55,17 @@ class AuthPhase:
         send_message: Callable[[bytes | str | dict[str, Any]], None],
         cancel_ws: CALLBACK_TYPE,
         request: Request,
+        send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
     ) -> None:
-        """Initialize the authentiated connection."""
+        """Initialize the authenticated connection."""
         self._hass = hass
+        # send_message will send a message to the client via the queue.
         self._send_message = send_message
         self._cancel_ws = cancel_ws
         self._logger = logger
         self._request = request
+        # send_bytes_text will directly send a message to the client.
+        self._send_bytes_text = send_bytes_text
 
     async def async_handle(self, msg: JsonValueType) -> ActiveConnection:
         """Handle authentication."""
@@ -73,7 +76,7 @@ class AuthPhase:
                 f"Auth message incorrectly formatted: {humanize_error(msg, err)}"
             )
             self._logger.warning(error_msg)
-            self._send_message(auth_invalid_message(error_msg))
+            await self._send_bytes_text(auth_invalid_message(error_msg))
             raise Disconnect from err
 
         if (access_token := valid_msg.get("access_token")) and (
@@ -81,26 +84,25 @@ class AuthPhase:
                 access_token
             )
         ):
-            conn = await self._async_finish_auth(refresh_token.user, refresh_token)
+            conn = ActiveConnection(
+                self._logger,
+                self._hass,
+                self._send_message,
+                refresh_token.user,
+                refresh_token,
+            )
             conn.subscriptions[
                 "auth"
             ] = self._hass.auth.async_register_revoke_token_callback(
                 refresh_token.id, self._cancel_ws
             )
-
+            await self._send_bytes_text(AUTH_OK_MESSAGE)
+            self._logger.debug("Auth OK")
+            process_success_login(self._request)
             return conn
 
-        self._send_message(auth_invalid_message("Invalid access token or password"))
+        await self._send_bytes_text(
+            auth_invalid_message("Invalid access token or password")
+        )
         await process_wrong_login(self._request)
         raise Disconnect
-
-    async def _async_finish_auth(
-        self, user: User, refresh_token: RefreshToken
-    ) -> ActiveConnection:
-        """Create an active connection."""
-        self._logger.debug("Auth OK")
-        process_success_login(self._request)
-        self._send_message(AUTH_OK_MESSAGE)
-        return ActiveConnection(
-            self._logger, self._hass, self._send_message, user, refresh_token
-        )
diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py
index d966e4e26ef..416573d493c 100644
--- a/homeassistant/components/websocket_api/http.py
+++ b/homeassistant/components/websocket_api/http.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import asyncio
 from collections import deque
-from collections.abc import Callable
+from collections.abc import Callable, Coroutine
 import datetime as dt
 from functools import partial
 import logging
@@ -116,16 +116,14 @@ class WebSocketHandler:
             return describe_request(request)
         return "finished connection"
 
-    async def _writer(self) -> None:
+    async def _writer(
+        self, send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]]
+    ) -> None:
         """Write outgoing messages."""
         # Variables are set locally to avoid lookups in the loop
         message_queue = self._message_queue
         logger = self._logger
         wsock = self._wsock
-        writer = wsock._writer  # pylint: disable=protected-access
-        if TYPE_CHECKING:
-            assert writer is not None
-        send_str = partial(writer.send, binary=False)
         loop = self._hass.loop
         debug = logger.debug
         is_enabled_for = logger.isEnabledFor
@@ -152,7 +150,7 @@ class WebSocketHandler:
                 ):
                     if debug_enabled:
                         debug("%s: Sending %s", self.description, message)
-                    await send_str(message)
+                    await send_bytes_text(message)
                     continue
 
                 messages: list[bytes] = [message]
@@ -166,7 +164,7 @@ class WebSocketHandler:
                 coalesced_messages = b"".join((b"[", b",".join(messages), b"]"))
                 if debug_enabled:
                     debug("%s: Sending %s", self.description, coalesced_messages)
-                await send_str(coalesced_messages)
+                await send_bytes_text(coalesced_messages)
         except asyncio.CancelledError:
             debug("%s: Writer cancelled", self.description)
             raise
@@ -186,7 +184,7 @@ class WebSocketHandler:
 
     @callback
     def _send_message(self, message: str | bytes | dict[str, Any]) -> None:
-        """Send a message to the client.
+        """Queue sending a message to the client.
 
         Closes connection if the client is not reading the messages.
 
@@ -295,21 +293,23 @@ class WebSocketHandler:
             EVENT_HOMEASSISTANT_STOP, self._async_handle_hass_stop
         )
 
-        # As the webserver is now started before the start
-        # event we do not want to block for websocket responses
-        self._writer_task = asyncio.create_task(self._writer())
+        writer = wsock._writer  # pylint: disable=protected-access
+        if TYPE_CHECKING:
+            assert writer is not None
 
-        auth = AuthPhase(logger, hass, self._send_message, self._cancel, request)
+        send_bytes_text = partial(writer.send, binary=False)
+        auth = AuthPhase(
+            logger, hass, self._send_message, self._cancel, request, send_bytes_text
+        )
         connection = None
         disconnect_warn = None
 
         try:
-            self._send_message(AUTH_REQUIRED_MESSAGE)
+            await send_bytes_text(AUTH_REQUIRED_MESSAGE)
 
             # Auth Phase
             try:
-                async with asyncio.timeout(10):
-                    msg = await wsock.receive()
+                msg = await wsock.receive(10)
             except asyncio.TimeoutError as err:
                 disconnect_warn = "Did not receive auth message within 10 seconds"
                 raise Disconnect from err
@@ -330,7 +330,13 @@ class WebSocketHandler:
             if is_enabled_for(logging_debug):
                 debug("%s: Received %s", self.description, auth_msg_data)
             connection = await auth.async_handle(auth_msg_data)
+            # As the webserver is now started before the start
+            # event we do not want to block for websocket responses
+            #
+            # We only start the writer queue after the auth phase is completed
+            # since there is no need to queue messages before the auth phase
             self._connection = connection
+            self._writer_task = asyncio.create_task(self._writer(send_bytes_text))
             hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
             async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)
 
@@ -370,7 +376,7 @@ class WebSocketHandler:
             # added a way to set the limit, but there is no way to actually
             # reach the code to set the limit, so we have to set it directly.
             #
-            wsock._writer._limit = 2**20  # type: ignore[union-attr] # pylint: disable=protected-access
+            writer._limit = 2**20  # pylint: disable=protected-access
             async_handle_str = connection.async_handle
             async_handle_binary = connection.async_handle_binary
 
@@ -441,7 +447,8 @@ class WebSocketHandler:
             # so we have another finally block to make sure we close the websocket
             # if the writer gets canceled.
             try:
-                await self._writer_task
+                if self._writer_task:
+                    await self._writer_task
             finally:
                 try:
                     # Make sure all error messages are written before closing
-- 
GitLab