From eb8f8e1ae46c5eb02f1f117dd3bc134a2cade1d7 Mon Sep 17 00:00:00 2001
From: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
Date: Fri, 8 Mar 2024 11:13:24 +0100
Subject: [PATCH] Use aiohttp.AppKey for http ban keys (#112657)

---
 homeassistant/components/http/ban.py | 31 ++++++++++++++++------------
 tests/components/http/test_ban.py    |  7 +++----
 2 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/homeassistant/components/http/ban.py b/homeassistant/components/http/ban.py
index bec11d6e5ff..e5a65c2fe72 100644
--- a/homeassistant/components/http/ban.py
+++ b/homeassistant/components/http/ban.py
@@ -11,7 +11,14 @@ import logging
 from socket import gethostbyaddr, herror
 from typing import Any, Concatenate, Final, ParamSpec, TypeVar
 
-from aiohttp.web import Application, Request, Response, StreamResponse, middleware
+from aiohttp.web import (
+    AppKey,
+    Application,
+    Request,
+    Response,
+    StreamResponse,
+    middleware,
+)
 from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
 import voluptuous as vol
 
@@ -29,9 +36,11 @@ _P = ParamSpec("_P")
 
 _LOGGER: Final = logging.getLogger(__name__)
 
-KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
-KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts"
-KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold"
+KEY_BAN_MANAGER = AppKey["IpBanManager"]("ha_banned_ips_manager")
+KEY_FAILED_LOGIN_ATTEMPTS = AppKey[defaultdict[IPv4Address | IPv6Address, int]](
+    "ha_failed_login_attempts"
+)
+KEY_LOGIN_THRESHOLD = AppKey[int]("ban_manager.ip_bans_lookup")
 
 NOTIFICATION_ID_BAN: Final = "ip-ban"
 NOTIFICATION_ID_LOGIN: Final = "http-login"
@@ -48,7 +57,7 @@ SCHEMA_IP_BAN_ENTRY: Final = vol.Schema(
 def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> None:
     """Create IP Ban middleware for the app."""
     app.middlewares.append(ban_middleware)
-    app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
+    app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict[IPv4Address | IPv6Address, int](int)
     app[KEY_LOGIN_THRESHOLD] = login_threshold
     app[KEY_BAN_MANAGER] = IpBanManager(hass)
 
@@ -64,13 +73,11 @@ async def ban_middleware(
     request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
 ) -> StreamResponse:
     """IP Ban middleware."""
-    ban_manager: IpBanManager | None = request.app.get(KEY_BAN_MANAGER)
-    if ban_manager is None:
+    if (ban_manager := request.app.get(KEY_BAN_MANAGER)) is None:
         _LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
         return await handler(request)
 
-    ip_bans_lookup = ban_manager.ip_bans_lookup
-    if ip_bans_lookup:
+    if ip_bans_lookup := ban_manager.ip_bans_lookup:
         # Verify if IP is not banned
         ip_address_ = ip_address(request.remote)  # type: ignore[arg-type]
         if ip_address_ in ip_bans_lookup:
@@ -154,7 +161,7 @@ async def process_wrong_login(request: Request) -> None:
         request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr]
         >= request.app[KEY_LOGIN_THRESHOLD]
     ):
-        ban_manager: IpBanManager = request.app[KEY_BAN_MANAGER]
+        ban_manager = request.app[KEY_BAN_MANAGER]
         _LOGGER.warning("Banned IP %s for too many login attempts", remote_addr)
         await ban_manager.async_add_ban(remote_addr)
 
@@ -180,9 +187,7 @@ def process_success_login(request: Request) -> None:
         return
 
     remote_addr = ip_address(request.remote)  # type: ignore[arg-type]
-    login_attempt_history: defaultdict[IPv4Address | IPv6Address, int] = app[
-        KEY_FAILED_LOGIN_ATTEMPTS
-    ]
+    login_attempt_history = app[KEY_FAILED_LOGIN_ATTEMPTS]
     if remote_addr in login_attempt_history and login_attempt_history[remote_addr] > 0:
         _LOGGER.debug(
             "Login success, reset failed login attempts counter from %s", remote_addr
diff --git a/tests/components/http/test_ban.py b/tests/components/http/test_ban.py
index 26301cf5b79..5ab9db4e64e 100644
--- a/tests/components/http/test_ban.py
+++ b/tests/components/http/test_ban.py
@@ -15,7 +15,6 @@ from homeassistant.components.http.ban import (
     IP_BANS_FILE,
     KEY_BAN_MANAGER,
     KEY_FAILED_LOGIN_ATTEMPTS,
-    IpBanManager,
     process_success_login,
     setup_bans,
 )
@@ -215,7 +214,7 @@ async def test_access_from_supervisor_ip(
     ):
         client = await aiohttp_client(app)
 
-    manager: IpBanManager = app[KEY_BAN_MANAGER]
+    manager = app[KEY_BAN_MANAGER]
 
     with patch(
         "homeassistant.components.hassio.HassIO.get_resolution_info",
@@ -288,7 +287,7 @@ async def test_ip_bans_file_creation(
     ):
         client = await aiohttp_client(app)
 
-    manager: IpBanManager = app[KEY_BAN_MANAGER]
+    manager = app[KEY_BAN_MANAGER]
     m_open = mock_open()
 
     with patch("homeassistant.components.http.ban.open", m_open, create=True):
@@ -408,7 +407,7 @@ async def test_single_ban_file_entry(
     setup_bans(hass, app, 2)
     mock_real_ip(app)("200.201.202.204")
 
-    manager: IpBanManager = app[KEY_BAN_MANAGER]
+    manager = app[KEY_BAN_MANAGER]
     m_open = mock_open()
 
     with patch("homeassistant.components.http.ban.open", m_open, create=True):
-- 
GitLab