From 533442f33e8c8419f40928eb242ea056c0bc1ebc Mon Sep 17 00:00:00 2001
From: Erik Montnemery <erik@montnemery.com>
Date: Sat, 17 Aug 2024 11:01:49 +0200
Subject: [PATCH] Add async friendly helper for validating config schemas
 (#123800)

* Add async friendly helper for validating config schemas

* Improve docstrings

* Add tests
---
 homeassistant/config.py                    |  8 +-
 homeassistant/helpers/check_config.py      |  6 +-
 homeassistant/helpers/config_validation.py | 85 +++++++++++++++++++++-
 tests/helpers/test_config_validation.py    | 69 +++++++++++++++++-
 4 files changed, 161 insertions(+), 7 deletions(-)

diff --git a/homeassistant/config.py b/homeassistant/config.py
index 948ab342e79..5066dffab47 100644
--- a/homeassistant/config.py
+++ b/homeassistant/config.py
@@ -1535,7 +1535,9 @@ async def async_process_component_config(
     # No custom config validator, proceed with schema validation
     if hasattr(component, "CONFIG_SCHEMA"):
         try:
-            return IntegrationConfigInfo(component.CONFIG_SCHEMA(config), [])
+            return IntegrationConfigInfo(
+                await cv.async_validate(hass, component.CONFIG_SCHEMA, config), []
+            )
         except vol.Invalid as exc:
             exc_info = ConfigExceptionInfo(
                 exc,
@@ -1570,7 +1572,9 @@ async def async_process_component_config(
         # Validate component specific platform schema
         platform_path = f"{p_name}.{domain}"
         try:
-            p_validated = component_platform_schema(p_config)
+            p_validated = await cv.async_validate(
+                hass, component_platform_schema, p_config
+            )
         except vol.Invalid as exc:
             exc_info = ConfigExceptionInfo(
                 exc,
diff --git a/homeassistant/helpers/check_config.py b/homeassistant/helpers/check_config.py
index 06d836e8c20..43021fffac5 100644
--- a/homeassistant/helpers/check_config.py
+++ b/homeassistant/helpers/check_config.py
@@ -234,7 +234,7 @@ async def async_check_ha_config_file(  # noqa: C901
         config_schema = getattr(component, "CONFIG_SCHEMA", None)
         if config_schema is not None:
             try:
-                validated_config = config_schema(config)
+                validated_config = await cv.async_validate(hass, config_schema, config)
                 # Don't fail if the validator removed the domain from the config
                 if domain in validated_config:
                     result[domain] = validated_config[domain]
@@ -255,7 +255,9 @@ async def async_check_ha_config_file(  # noqa: C901
         for p_name, p_config in config_per_platform(config, domain):
             # Validate component specific platform schema
             try:
-                p_validated = component_platform_schema(p_config)
+                p_validated = await cv.async_validate(
+                    hass, component_platform_schema, p_config
+                )
             except vol.Invalid as ex:
                 _comp_error(ex, domain, p_config, p_config)
                 continue
diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py
index 6e9a6d5a69d..d7a5d5ae8a1 100644
--- a/homeassistant/helpers/config_validation.py
+++ b/homeassistant/helpers/config_validation.py
@@ -6,6 +6,7 @@
 
 from collections.abc import Callable, Hashable
 import contextlib
+from contextvars import ContextVar
 from datetime import (
     date as date_sys,
     datetime as datetime_sys,
@@ -13,6 +14,7 @@ from datetime import (
     timedelta,
 )
 from enum import Enum, StrEnum
+import functools
 import logging
 from numbers import Number
 import os
@@ -20,6 +22,7 @@ import re
 from socket import (  # type: ignore[attr-defined]  # private, not in typeshed
     _GLOBAL_DEFAULT_TIMEOUT,
 )
+import threading
 from typing import Any, cast, overload
 from urllib.parse import urlparse
 from uuid import UUID
@@ -94,6 +97,7 @@ from homeassistant.const import (
 )
 from homeassistant.core import (
     DOMAIN as HOMEASSISTANT_DOMAIN,
+    HomeAssistant,
     async_get_hass,
     async_get_hass_or_none,
     split_entity_id,
@@ -114,6 +118,51 @@ from .typing import VolDictType, VolSchemaType
 TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM', 'HH:MM:SS' or 'HH:MM:SS.F'"
 
 
+class MustValidateInExecutor(HomeAssistantError):
+    """Raised by @not_async_friendly validators to defer validation to an executor."""
+
+
+class _Hass(threading.local):
+    """Thread-local holder making a HomeAssistant instance available to validators."""
+
+    hass: HomeAssistant | None = None  # Default seen by threads which never set it
+
+
+_hass = _Hass()
+"""Set when doing async friendly schema validation."""
+
+
+def _async_get_hass_or_none() -> HomeAssistant | None:
+    """Return the HomeAssistant instance or None.
+
+    First tries core.async_get_hass_or_none, then falls back to the thread-local
+    _hass, which _validate_in_executor sets while validating in an executor.
+    """
+    return async_get_hass_or_none() or _hass.hass
+
+
+_validating_async: ContextVar[bool] = ContextVar("_validating_async", default=False)
+"""Set to True when doing async friendly schema validation."""
+
+
+def not_async_friendly[**_P, _R](validator: Callable[_P, _R]) -> Callable[_P, _R]:
+    """Mark a validator as not async friendly.
+
+    When called in the event loop during async_validate, the wrapped validator
+    raises MustValidateInExecutor instead of running; otherwise it runs as usual.
+    """
+
+    @functools.wraps(validator)
+    def _not_async_friendly(*args: _P.args, **kwargs: _P.kwargs) -> _R:
+        if _validating_async.get() and async_get_hass_or_none():
+            # async_validate is active and we are in the event loop: bail out
+            # so async_validate can redo the validation in an executor
+            raise MustValidateInExecutor
+        return validator(*args, **kwargs)
+
+    return _not_async_friendly
+
+
 class UrlProtocolSchema(StrEnum):
     """Valid URL protocol schema values."""
 
@@ -217,6 +266,7 @@ def whitespace(value: Any) -> str:
     raise vol.Invalid(f"contains non-whitespace: {value}")
 
 
+@not_async_friendly
 def isdevice(value: Any) -> str:
     """Validate that value is a real device."""
     try:
@@ -258,6 +308,7 @@ def is_regex(value: Any) -> re.Pattern[Any]:
     return r
 
 
+@not_async_friendly
 def isfile(value: Any) -> str:
     """Validate that the value is an existing file."""
     if value is None:
@@ -271,6 +322,7 @@ def isfile(value: Any) -> str:
     return file_in
 
 
+@not_async_friendly
 def isdir(value: Any) -> str:
     """Validate that the value is an existing dir."""
     if value is None:
@@ -664,7 +716,7 @@ def template(value: Any | None) -> template_helper.Template:
     if isinstance(value, (list, dict, template_helper.Template)):
         raise vol.Invalid("template value should be a string")
 
-    template_value = template_helper.Template(str(value), async_get_hass_or_none())
+    template_value = template_helper.Template(str(value), _async_get_hass_or_none())
 
     try:
         template_value.ensure_valid()
@@ -682,7 +734,7 @@ def dynamic_template(value: Any | None) -> template_helper.Template:
     if not template_helper.is_template_string(str(value)):
         raise vol.Invalid("template value does not contain a dynamic template")
 
-    template_value = template_helper.Template(str(value), async_get_hass_or_none())
+    template_value = template_helper.Template(str(value), _async_get_hass_or_none())
 
     try:
         template_value.ensure_valid()
@@ -1918,3 +1970,32 @@ historic_currency = vol.In(
 country = vol.In(COUNTRIES, msg="invalid ISO 3166 formatted country")
 
 language = vol.In(LANGUAGES, msg="invalid RFC 5646 formatted language")
+
+
+async def async_validate(
+    hass: HomeAssistant, validator: Callable[[Any], Any], value: Any
+) -> Any:
+    """Async friendly schema validation.
+
+    If a validator decorated with @not_async_friendly is called, validation is
+    restarted from scratch in an executor. If not, it happens in the event loop.
+    """
+    token = _validating_async.set(True)
+    try:
+        return validator(value)
+    except MustValidateInExecutor:
+        return await hass.async_add_executor_job(
+            _validate_in_executor, hass, validator, value
+        )
+    finally:
+        _validating_async.reset(token)  # Restore the previous value, not just False
+
+
+def _validate_in_executor(
+    hass: HomeAssistant, validator: Callable[[Any], Any], value: Any
+) -> Any:
+    _hass.hass = hass  # Expose hass to validators running in this thread
+    try:
+        return validator(value)
+    finally:
+        _hass.hass = None  # Clear again, even if validation raised
diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py
index ac3af13949b..973f504df08 100644
--- a/tests/helpers/test_config_validation.py
+++ b/tests/helpers/test_config_validation.py
@@ -3,13 +3,16 @@
 from collections import OrderedDict
 from datetime import date, datetime, timedelta
 import enum
+from functools import partial
 import logging
 import os
 from socket import _GLOBAL_DEFAULT_TIMEOUT
+import threading
 from typing import Any
-from unittest.mock import Mock, patch
+from unittest.mock import ANY, Mock, patch
 import uuid
 
+import py
 import pytest
 import voluptuous as vol
 
@@ -1738,3 +1741,67 @@ def test_determine_script_action_ambiguous() -> None:
 def test_determine_script_action_non_ambiguous() -> None:
     """Test determine script action with a non ambiguous action."""
     assert cv.determine_script_action({"delay": "00:00:05"}) == "delay"
+
+
+async def test_async_validate(hass: HomeAssistant, tmpdir: py.path.local) -> None:
+    """Test async_validate runs validators in the event loop or an executor."""
+    validator_calls: dict[str, list[int]] = {}
+
+    def _mock_validator_schema(real_func, *args):
+        calls = validator_calls.setdefault(real_func.__name__, [])
+        calls.append(threading.get_ident())
+        return real_func(*args)
+
+    CV_PREFIX = "homeassistant.helpers.config_validation"
+    with (
+        patch(f"{CV_PREFIX}.isdir", wraps=partial(_mock_validator_schema, cv.isdir)),
+        patch(f"{CV_PREFIX}.string", wraps=partial(_mock_validator_schema, cv.string)),
+    ):
+        # Assert validation in event loop when not decorated with not_async_friendly
+        await cv.async_validate(hass, cv.string, "abcd")
+        assert validator_calls == {"string": [hass.loop_thread_id]}
+        validator_calls = {}
+
+        # Assert decorated validator is tried in the event loop, then run in executor
+        await cv.async_validate(hass, cv.isdir, tmpdir)
+        assert validator_calls == {"isdir": [hass.loop_thread_id, ANY]}
+        assert validator_calls["isdir"][1] != hass.loop_thread_id
+        validator_calls = {}
+
+        # Assert validators after a decorated one only run in the executor
+        await cv.async_validate(hass, vol.All(cv.isdir, cv.string), tmpdir)
+        assert validator_calls == {"isdir": [hass.loop_thread_id, ANY], "string": [ANY]}
+        assert validator_calls["isdir"][1] != hass.loop_thread_id
+        assert validator_calls["string"][0] != hass.loop_thread_id
+        validator_calls = {}
+
+        # Assert validators before a decorated one run again in the executor
+        await cv.async_validate(hass, vol.All(cv.string, cv.isdir), tmpdir)
+        assert validator_calls == {
+            "isdir": [hass.loop_thread_id, ANY],
+            "string": [hass.loop_thread_id, ANY],
+        }
+        assert validator_calls["isdir"][1] != hass.loop_thread_id
+        assert validator_calls["string"][1] != hass.loop_thread_id
+        validator_calls = {}
+
+        # Assert validation in event loop when not using cv.async_validate
+        cv.isdir(tmpdir)
+        assert validator_calls == {"isdir": [hass.loop_thread_id]}
+        validator_calls = {}
+
+        # Assert validation in event loop when not using cv.async_validate
+        vol.All(cv.isdir, cv.string)(tmpdir)
+        assert validator_calls == {
+            "isdir": [hass.loop_thread_id],
+            "string": [hass.loop_thread_id],
+        }
+        validator_calls = {}
+
+        # Assert validation in event loop when not using cv.async_validate
+        vol.All(cv.string, cv.isdir)(tmpdir)
+        assert validator_calls == {
+            "isdir": [hass.loop_thread_id],
+            "string": [hass.loop_thread_id],
+        }
+        validator_calls = {}
-- 
GitLab