From 533442f33e8c8419f40928eb242ea056c0bc1ebc Mon Sep 17 00:00:00 2001 From: Erik Montnemery <erik@montnemery.com> Date: Sat, 17 Aug 2024 11:01:49 +0200 Subject: [PATCH] Add async friendly helper for validating config schemas (#123800) * Add async friendly helper for validating config schemas * Improve docstrings * Add tests --- homeassistant/config.py | 8 +- homeassistant/helpers/check_config.py | 6 +- homeassistant/helpers/config_validation.py | 85 +++++++++++++++++++++- tests/helpers/test_config_validation.py | 69 +++++++++++++++++- 4 files changed, 161 insertions(+), 7 deletions(-) diff --git a/homeassistant/config.py b/homeassistant/config.py index 948ab342e79..5066dffab47 100644 --- a/homeassistant/config.py +++ b/homeassistant/config.py @@ -1535,7 +1535,9 @@ async def async_process_component_config( # No custom config validator, proceed with schema validation if hasattr(component, "CONFIG_SCHEMA"): try: - return IntegrationConfigInfo(component.CONFIG_SCHEMA(config), []) + return IntegrationConfigInfo( + await cv.async_validate(hass, component.CONFIG_SCHEMA, config), [] + ) except vol.Invalid as exc: exc_info = ConfigExceptionInfo( exc, @@ -1570,7 +1572,9 @@ async def async_process_component_config( # Validate component specific platform schema platform_path = f"{p_name}.{domain}" try: - p_validated = component_platform_schema(p_config) + p_validated = await cv.async_validate( + hass, component_platform_schema, p_config + ) except vol.Invalid as exc: exc_info = ConfigExceptionInfo( exc, diff --git a/homeassistant/helpers/check_config.py b/homeassistant/helpers/check_config.py index 06d836e8c20..43021fffac5 100644 --- a/homeassistant/helpers/check_config.py +++ b/homeassistant/helpers/check_config.py @@ -234,7 +234,7 @@ async def async_check_ha_config_file( # noqa: C901 config_schema = getattr(component, "CONFIG_SCHEMA", None) if config_schema is not None: try: - validated_config = config_schema(config) + validated_config = await cv.async_validate(hass, config_schema, config) # Don't fail if the validator removed the domain from the config if domain in validated_config: result[domain] = validated_config[domain] @@ -255,7 +255,9 @@ async def async_check_ha_config_file( # noqa: C901 for p_name, p_config in config_per_platform(config, domain): # Validate component specific platform schema try: - p_validated = component_platform_schema(p_config) + p_validated = await cv.async_validate( + hass, component_platform_schema, p_config + ) except vol.Invalid as ex: _comp_error(ex, domain, p_config, p_config) continue diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 6e9a6d5a69d..d7a5d5ae8a1 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -6,6 +6,7 @@ from collections.abc import Callable, Hashable import contextlib +from contextvars import ContextVar from datetime import ( date as date_sys, datetime as datetime_sys, @@ -13,6 +14,7 @@ from datetime import ( timedelta, ) from enum import Enum, StrEnum +import functools import logging from numbers import Number import os @@ -20,6 +22,7 @@ import re from socket import ( # type: ignore[attr-defined] # private, not in typeshed _GLOBAL_DEFAULT_TIMEOUT, ) +import threading from typing import Any, cast, overload from urllib.parse import urlparse from uuid import UUID @@ -94,6 +97,7 @@ from homeassistant.const import ( ) from homeassistant.core import ( DOMAIN as HOMEASSISTANT_DOMAIN, + HomeAssistant, async_get_hass, async_get_hass_or_none, split_entity_id, @@ -114,6 +118,51 @@ from .typing import VolDictType, VolSchemaType TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM', 'HH:MM:SS' or 'HH:MM:SS.F'" +class MustValidateInExecutor(HomeAssistantError): + """Raised when validation must happen in an executor thread.""" + + +class _Hass(threading.local): + """Container which makes a HomeAssistant instance available to validators.""" + + hass: HomeAssistant | None = None + + +_hass = _Hass() +"""Set when doing async friendly schema validation.""" + + +def _async_get_hass_or_none() -> HomeAssistant | None: + """Return the HomeAssistant instance or None. + + First tries core.async_get_hass_or_none, then _hass which is + set when doing async friendly schema validation. + """ + return async_get_hass_or_none() or _hass.hass + + +_validating_async: ContextVar[bool] = ContextVar("_validating_async", default=False) +"""Set to True when doing async friendly schema validation.""" + + +def not_async_friendly[**_P, _R](validator: Callable[_P, _R]) -> Callable[_P, _R]: + """Mark a validator as not async friendly. + + This makes validation happen in an executor thread if validation is done by + async_validate, otherwise does nothing. + """ + + @functools.wraps(validator) + def _not_async_friendly(*args: _P.args, **kwargs: _P.kwargs) -> _R: + if _validating_async.get() and async_get_hass_or_none(): + # Raise if doing async friendly validation and validation + # is happening in the event loop + raise MustValidateInExecutor + return validator(*args, **kwargs) + + return _not_async_friendly + + class UrlProtocolSchema(StrEnum): """Valid URL protocol schema values.""" @@ -217,6 +266,7 @@ def whitespace(value: Any) -> str: raise vol.Invalid(f"contains non-whitespace: {value}") +@not_async_friendly def isdevice(value: Any) -> str: """Validate that value is a real device.""" try: @@ -258,6 +308,7 @@ def is_regex(value: Any) -> re.Pattern[Any]: return r +@not_async_friendly def isfile(value: Any) -> str: """Validate that the value is an existing file.""" if value is None: @@ -271,6 +322,7 @@ def isfile(value: Any) -> str: return file_in +@not_async_friendly def isdir(value: Any) -> str: """Validate that the value is an existing dir.""" if value is None: @@ -664,7 +716,7 @@ def template(value: Any | None) -> template_helper.Template: if isinstance(value, (list, dict, template_helper.Template)): raise vol.Invalid("template value should be a string") - template_value = template_helper.Template(str(value), async_get_hass_or_none()) + template_value = template_helper.Template(str(value), _async_get_hass_or_none()) try: template_value.ensure_valid() @@ -682,7 +734,7 @@ def dynamic_template(value: Any | None) -> template_helper.Template: if not template_helper.is_template_string(str(value)): raise vol.Invalid("template value does not contain a dynamic template") - template_value = template_helper.Template(str(value), async_get_hass_or_none()) + template_value = template_helper.Template(str(value), _async_get_hass_or_none()) try: template_value.ensure_valid() @@ -1918,3 +1970,32 @@ historic_currency = vol.In( country = vol.In(COUNTRIES, msg="invalid ISO 3166 formatted country") language = vol.In(LANGUAGES, msg="invalid RFC 5646 formatted language") + + +async def async_validate( + hass: HomeAssistant, validator: Callable[[Any], Any], value: Any +) -> Any: + """Async friendly schema validation. + + If a validator decorated with @not_async_friendly is called, validation will be + deferred to an executor. If not, validation will happen in the event loop. + """ + _validating_async.set(True) + try: + return validator(value) + except MustValidateInExecutor: + return await hass.async_add_executor_job( + _validate_in_executor, hass, validator, value + ) + finally: + _validating_async.set(False) + + +def _validate_in_executor( + hass: HomeAssistant, validator: Callable[[Any], Any], value: Any +) -> Any: + _hass.hass = hass + try: + return validator(value) + finally: + _hass.hass = None diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index ac3af13949b..973f504df08 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -3,13 +3,16 @@ from collections import OrderedDict from datetime import date, datetime, timedelta import enum +from functools import partial import logging import os from socket import _GLOBAL_DEFAULT_TIMEOUT +import threading from typing import Any -from unittest.mock import Mock, patch +from unittest.mock import ANY, Mock, patch import uuid +import py import pytest import voluptuous as vol @@ -1738,3 +1741,67 @@ def test_determine_script_action_ambiguous() -> None: def test_determine_script_action_non_ambiguous() -> None: """Test determine script action with a non ambiguous action.""" assert cv.determine_script_action({"delay": "00:00:05"}) == "delay" + + +async def test_async_validate(hass: HomeAssistant, tmpdir: py.path.local) -> None: + """Test the async_validate helper.""" + validator_calls: dict[str, list[int]] = {} + + def _mock_validator_schema(real_func, *args): + calls = validator_calls.setdefault(real_func.__name__, []) + calls.append(threading.get_ident()) + return real_func(*args) + + CV_PREFIX = "homeassistant.helpers.config_validation" + with ( + patch(f"{CV_PREFIX}.isdir", wraps=partial(_mock_validator_schema, cv.isdir)), + patch(f"{CV_PREFIX}.string", wraps=partial(_mock_validator_schema, cv.string)), + ): + # Assert validation in event loop when not decorated with not_async_friendly + await cv.async_validate(hass, cv.string, "abcd") + assert validator_calls == {"string": [hass.loop_thread_id]} + validator_calls = {} + + # Assert validation in executor when decorated with not_async_friendly + await cv.async_validate(hass, cv.isdir, tmpdir) + assert validator_calls == {"isdir": [hass.loop_thread_id, ANY]} + assert validator_calls["isdir"][1] != hass.loop_thread_id + validator_calls = {} + + # Assert validation in executor when decorated with not_async_friendly + await cv.async_validate(hass, vol.All(cv.isdir, cv.string), tmpdir) + assert validator_calls == {"isdir": [hass.loop_thread_id, ANY], "string": [ANY]} + assert validator_calls["isdir"][1] != hass.loop_thread_id + assert validator_calls["string"][0] != hass.loop_thread_id + validator_calls = {} + + # Assert validation in executor when decorated with not_async_friendly + await cv.async_validate(hass, vol.All(cv.string, cv.isdir), tmpdir) + assert validator_calls == { + "isdir": [hass.loop_thread_id, ANY], + "string": [hass.loop_thread_id, ANY], + } + assert validator_calls["isdir"][1] != hass.loop_thread_id + assert validator_calls["string"][1] != hass.loop_thread_id + validator_calls = {} + + # Assert validation in event loop when not using cv.async_validate + cv.isdir(tmpdir) + assert validator_calls == {"isdir": [hass.loop_thread_id]} + validator_calls = {} + + # Assert validation in event loop when not using cv.async_validate + vol.All(cv.isdir, cv.string)(tmpdir) + assert validator_calls == { + "isdir": [hass.loop_thread_id], + "string": [hass.loop_thread_id], + } + validator_calls = {} + + # Assert validation in event loop when not using cv.async_validate + vol.All(cv.string, cv.isdir)(tmpdir) + assert validator_calls == { + "isdir": [hass.loop_thread_id], + "string": [hass.loop_thread_id], + } + validator_calls = {} -- GitLab