Skip to content
Snippets Groups Projects
Unverified Commit b1ed1543 authored by Marc Mueller's avatar Marc Mueller Committed by GitHub
Browse files

Improve http decorator typing (#75541)

parent 1d7d2875
No related branches found
No related tags found
No related merge requests found
......@@ -257,7 +257,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
@RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA))
@log_invalid_auth
async def post(self, request, flow_id, data):
async def post(self, request, data, flow_id):
"""Handle progressing a login flow request."""
client_id = data.pop("client_id")
......
......@@ -2,17 +2,18 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Coroutine
from contextlib import suppress
from datetime import datetime
from http import HTTPStatus
from ipaddress import IPv4Address, IPv6Address, ip_address
import logging
from socket import gethostbyaddr, herror
from typing import Any, Final
from typing import Any, Final, TypeVar
from aiohttp.web import Application, Request, StreamResponse, middleware
from aiohttp.web import Application, Request, Response, StreamResponse, middleware
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
from typing_extensions import Concatenate, ParamSpec
import voluptuous as vol
from homeassistant.components import persistent_notification
......@@ -24,6 +25,9 @@ from homeassistant.util import dt as dt_util, yaml
from .view import HomeAssistantView
_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
_P = ParamSpec("_P")
_LOGGER: Final = logging.getLogger(__name__)
KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
......@@ -82,13 +86,13 @@ async def ban_middleware(
def log_invalid_auth(
func: Callable[..., Awaitable[StreamResponse]]
) -> Callable[..., Awaitable[StreamResponse]]:
func: Callable[Concatenate[_HassViewT, Request, _P], Awaitable[Response]]
) -> Callable[Concatenate[_HassViewT, Request, _P], Coroutine[Any, Any, Response]]:
"""Decorate function to handle invalid auth or failed login attempts."""
async def handle_req(
view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any
) -> StreamResponse:
view: _HassViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
) -> Response:
"""Try to log failed login attempts if response status >= BAD_REQUEST."""
resp = await func(view, request, *args, **kwargs)
if resp.status >= HTTPStatus.BAD_REQUEST:
......
"""Decorator for view methods to help with data validation."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Coroutine
from functools import wraps
from http import HTTPStatus
import logging
from typing import Any
from typing import Any, TypeVar
from aiohttp import web
from typing_extensions import Concatenate, ParamSpec
import voluptuous as vol
from .view import HomeAssistantView
_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
_P = ParamSpec("_P")
_LOGGER = logging.getLogger(__name__)
......@@ -33,33 +37,40 @@ class RequestDataValidator:
self._allow_empty = allow_empty
def __call__(
self, method: Callable[..., Awaitable[web.StreamResponse]]
) -> Callable:
self,
method: Callable[
Concatenate[_HassViewT, web.Request, dict[str, Any], _P],
Awaitable[web.Response],
],
) -> Callable[
Concatenate[_HassViewT, web.Request, _P],
Coroutine[Any, Any, web.Response],
]:
"""Decorate a function."""
@wraps(method)
async def wrapper(
view: HomeAssistantView, request: web.Request, *args: Any, **kwargs: Any
) -> web.StreamResponse:
view: _HassViewT, request: web.Request, *args: _P.args, **kwargs: _P.kwargs
) -> web.Response:
"""Wrap a request handler with data validation."""
data = None
raw_data = None
try:
data = await request.json()
raw_data = await request.json()
except ValueError:
if not self._allow_empty or (await request.content.read()) != b"":
_LOGGER.error("Invalid JSON received")
return view.json_message("Invalid JSON.", HTTPStatus.BAD_REQUEST)
data = {}
raw_data = {}
try:
kwargs["data"] = self._schema(data)
data: dict[str, Any] = self._schema(raw_data)
except vol.Invalid as err:
_LOGGER.error("Data does not match schema: %s", err)
return view.json_message(
f"Message format incorrect: {err}", HTTPStatus.BAD_REQUEST
)
result = await method(view, request, *args, **kwargs)
result = await method(view, request, data, *args, **kwargs)
return result
return wrapper
......@@ -113,7 +113,7 @@ class RepairsFlowIndexView(FlowManagerIndexView):
result = self._prepare_result_json(result)
return self.json(result) # pylint: disable=arguments-differ
return self.json(result)
class RepairsFlowResourceView(FlowManagerResourceView):
......@@ -136,4 +136,4 @@ class RepairsFlowResourceView(FlowManagerResourceView):
raise Unauthorized(permission=POLICY_EDIT)
# pylint: disable=no-value-for-parameter
return await super().post(request, flow_id) # type: ignore[no-any-return]
return await super().post(request, flow_id)
......@@ -102,7 +102,7 @@ class FlowManagerResourceView(_BaseFlowManagerView):
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
async def post(
self, request: web.Request, flow_id: str, data: dict[str, Any]
self, request: web.Request, data: dict[str, Any], flow_id: str
) -> web.Response:
"""Handle a POST request."""
try:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment