From b1ed1543c8d17c2e1c936ae80a0671f7d35e4c99 Mon Sep 17 00:00:00 2001
From: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
Date: Thu, 21 Jul 2022 13:07:42 +0200
Subject: [PATCH] Improve http decorator typing (#75541)

---
 homeassistant/components/auth/login_flow.py   |  2 +-
 homeassistant/components/http/ban.py          | 18 ++++++----
 .../components/http/data_validator.py         | 33 ++++++++++++-------
 .../components/repairs/websocket_api.py       |  4 +--
 homeassistant/helpers/data_entry_flow.py      |  2 +-
 5 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/homeassistant/components/auth/login_flow.py b/homeassistant/components/auth/login_flow.py
index b24da92afdd..6cc9d94c7a6 100644
--- a/homeassistant/components/auth/login_flow.py
+++ b/homeassistant/components/auth/login_flow.py
@@ -257,7 +257,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
 
     @RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA))
     @log_invalid_auth
-    async def post(self, request, flow_id, data):
+    async def post(self, request, data, flow_id):
         """Handle progressing a login flow request."""
         client_id = data.pop("client_id")
 
diff --git a/homeassistant/components/http/ban.py b/homeassistant/components/http/ban.py
index ee8324b2791..d2f5f9d8ba5 100644
--- a/homeassistant/components/http/ban.py
+++ b/homeassistant/components/http/ban.py
@@ -2,17 +2,18 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from collections.abc import Awaitable, Callable
+from collections.abc import Awaitable, Callable, Coroutine
 from contextlib import suppress
 from datetime import datetime
 from http import HTTPStatus
 from ipaddress import IPv4Address, IPv6Address, ip_address
 import logging
 from socket import gethostbyaddr, herror
-from typing import Any, Final
+from typing import Any, Final, TypeVar
 
-from aiohttp.web import Application, Request, StreamResponse, middleware
+from aiohttp.web import Application, Request, Response, StreamResponse, middleware
 from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
+from typing_extensions import Concatenate, ParamSpec
 import voluptuous as vol
 
 from homeassistant.components import persistent_notification
@@ -24,6 +25,9 @@ from homeassistant.util import dt as dt_util, yaml
 
 from .view import HomeAssistantView
 
+_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
+_P = ParamSpec("_P")
+
 _LOGGER: Final = logging.getLogger(__name__)
 
 KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
@@ -82,13 +86,13 @@ async def ban_middleware(
 
 
 def log_invalid_auth(
-    func: Callable[..., Awaitable[StreamResponse]]
-) -> Callable[..., Awaitable[StreamResponse]]:
+    func: Callable[Concatenate[_HassViewT, Request, _P], Awaitable[Response]]
+) -> Callable[Concatenate[_HassViewT, Request, _P], Coroutine[Any, Any, Response]]:
     """Decorate function to handle invalid auth or failed login attempts."""
 
     async def handle_req(
-        view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any
-    ) -> StreamResponse:
+        view: _HassViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
+    ) -> Response:
         """Try to log failed login attempts if response status >= BAD_REQUEST."""
         resp = await func(view, request, *args, **kwargs)
         if resp.status >= HTTPStatus.BAD_REQUEST:
diff --git a/homeassistant/components/http/data_validator.py b/homeassistant/components/http/data_validator.py
index cc661d43fd8..6647a6436c5 100644
--- a/homeassistant/components/http/data_validator.py
+++ b/homeassistant/components/http/data_validator.py
@@ -1,17 +1,21 @@
 """Decorator for view methods to help with data validation."""
 from __future__ import annotations
 
-from collections.abc import Awaitable, Callable
+from collections.abc import Awaitable, Callable, Coroutine
 from functools import wraps
 from http import HTTPStatus
 import logging
-from typing import Any
+from typing import Any, TypeVar
 
 from aiohttp import web
+from typing_extensions import Concatenate, ParamSpec
 import voluptuous as vol
 
 from .view import HomeAssistantView
 
+_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
+_P = ParamSpec("_P")
+
 _LOGGER = logging.getLogger(__name__)
 
 
@@ -33,33 +37,40 @@ class RequestDataValidator:
         self._allow_empty = allow_empty
 
     def __call__(
-        self, method: Callable[..., Awaitable[web.StreamResponse]]
-    ) -> Callable:
+        self,
+        method: Callable[
+            Concatenate[_HassViewT, web.Request, dict[str, Any], _P],
+            Awaitable[web.Response],
+        ],
+    ) -> Callable[
+        Concatenate[_HassViewT, web.Request, _P],
+        Coroutine[Any, Any, web.Response],
+    ]:
         """Decorate a function."""
 
         @wraps(method)
         async def wrapper(
-            view: HomeAssistantView, request: web.Request, *args: Any, **kwargs: Any
-        ) -> web.StreamResponse:
+            view: _HassViewT, request: web.Request, *args: _P.args, **kwargs: _P.kwargs
+        ) -> web.Response:
             """Wrap a request handler with data validation."""
-            data = None
+            raw_data = None
             try:
-                data = await request.json()
+                raw_data = await request.json()
             except ValueError:
                 if not self._allow_empty or (await request.content.read()) != b"":
                     _LOGGER.error("Invalid JSON received")
                     return view.json_message("Invalid JSON.", HTTPStatus.BAD_REQUEST)
-                data = {}
+                raw_data = {}
 
             try:
-                kwargs["data"] = self._schema(data)
+                data: dict[str, Any] = self._schema(raw_data)
             except vol.Invalid as err:
                 _LOGGER.error("Data does not match schema: %s", err)
                 return view.json_message(
                     f"Message format incorrect: {err}", HTTPStatus.BAD_REQUEST
                 )
 
-            result = await method(view, request, *args, **kwargs)
+            result = await method(view, request, data, *args, **kwargs)
             return result
 
         return wrapper
diff --git a/homeassistant/components/repairs/websocket_api.py b/homeassistant/components/repairs/websocket_api.py
index b6a71773273..2e9fcc5f8e4 100644
--- a/homeassistant/components/repairs/websocket_api.py
+++ b/homeassistant/components/repairs/websocket_api.py
@@ -113,7 +113,7 @@ class RepairsFlowIndexView(FlowManagerIndexView):
 
         result = self._prepare_result_json(result)
 
-        return self.json(result)  # pylint: disable=arguments-differ
+        return self.json(result)
 
 
 class RepairsFlowResourceView(FlowManagerResourceView):
@@ -136,4 +136,4 @@ class RepairsFlowResourceView(FlowManagerResourceView):
             raise Unauthorized(permission=POLICY_EDIT)
 
         # pylint: disable=no-value-for-parameter
-        return await super().post(request, flow_id)  # type: ignore[no-any-return]
+        return await super().post(request, flow_id)
diff --git a/homeassistant/helpers/data_entry_flow.py b/homeassistant/helpers/data_entry_flow.py
index 444876a7674..428a62f0c9d 100644
--- a/homeassistant/helpers/data_entry_flow.py
+++ b/homeassistant/helpers/data_entry_flow.py
@@ -102,7 +102,7 @@ class FlowManagerResourceView(_BaseFlowManagerView):
 
     @RequestDataValidator(vol.Schema(dict), allow_empty=True)
     async def post(
-        self, request: web.Request, flow_id: str, data: dict[str, Any]
+        self, request: web.Request, data: dict[str, Any], flow_id: str
     ) -> web.Response:
         """Handle a POST request."""
         try:
-- 
GitLab