From 55c42fde88c7a6989a8489b755375fef92360f35 Mon Sep 17 00:00:00 2001 From: Erik Montnemery <erik@montnemery.com> Date: Tue, 27 Aug 2024 19:05:49 +0200 Subject: [PATCH] Improve validation of entity service schemas (#124102) * Improve validation of entity service schemas * Update tests/helpers/test_entity_platform.py Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com> --------- Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com> --- homeassistant/helpers/config_validation.py | 23 ++++++++++++- homeassistant/helpers/service.py | 13 ++------ tests/helpers/test_config_validation.py | 24 ++++++++++++++ tests/helpers/test_entity_component.py | 38 +++++++++++----------- tests/helpers/test_entity_platform.py | 38 ++++++++++++---------- 5 files changed, 87 insertions(+), 49 deletions(-) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index d7a5d5ae8a1..2904efb75e9 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -1305,9 +1305,28 @@ TARGET_SERVICE_FIELDS = { _HAS_ENTITY_SERVICE_FIELD = has_at_least_one_key(*ENTITY_SERVICE_FIELDS) +def is_entity_service_schema(validator: VolSchemaType) -> bool: + """Check if the passed validator is an entity schema validator. + + The validator must be either of: + - A validator returned by cv._make_entity_service_schema + - A validator returned by cv._make_entity_service_schema, wrapped in a vol.Schema + - A validator returned by cv._make_entity_service_schema, wrapped in a vol.All + Nesting is allowed. + """ + if hasattr(validator, "_entity_service_schema"): + return True + if isinstance(validator, (vol.All)): + return any(is_entity_service_schema(val) for val in validator.validators) + if isinstance(validator, (vol.Schema)): + return is_entity_service_schema(validator.schema) + + return False + + def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType: """Create an entity service schema.""" - return vol.All( + validator = vol.All( vol.Schema( { # The frontend stores data here. Don't use in core. @@ -1319,6 +1338,8 @@ def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType: ), _HAS_ENTITY_SERVICE_FIELD, ) + setattr(validator, "_entity_service_schema", True) + return validator BASE_ENTITY_SCHEMA = _make_entity_service_schema({}, vol.PREVENT_EXTRA) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 0551b5289c5..573073f3809 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -1267,17 +1267,8 @@ def async_register_entity_service( # Do a sanity check to check this is a valid entity service schema, # the check could be extended to require All/Any to have sub schema(s) # with all entity service fields - elif ( - # Don't check All/Any - not isinstance(schema, (vol.All, vol.Any)) - # Don't check All/Any wrapped in schema - and not isinstance(schema.schema, (vol.All, vol.Any)) - and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS) - ): - raise HomeAssistantError( - "The schema does not include all required keys: " - f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}" - ) + elif not cv.is_entity_service_schema(schema): + raise HomeAssistantError("The schema is not an entity service schema") service_func: str | HassJob[..., Any] service_func = func if isinstance(func, str) else HassJob(func) diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 973f504df08..57c712e2f10 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -1805,3 +1805,27 @@ async def test_async_validate(hass: HomeAssistant, tmpdir: py.path.local) -> Non "string": [hass.loop_thread_id], } validator_calls = {} + + +async def test_is_entity_service_schema( + hass: HomeAssistant, +) -> None: + """Test cv.is_entity_service_schema.""" + for schema in ( + vol.Schema({"some": str}), + vol.All(vol.Schema({"some": str})), + vol.Any(vol.Schema({"some": str})), + vol.Any(cv.make_entity_service_schema({"some": str})), + ): + assert cv.is_entity_service_schema(schema) is False + + for schema in ( + cv.make_entity_service_schema({"some": str}), + vol.Schema(cv.make_entity_service_schema({"some": str})), + vol.Schema(vol.All(cv.make_entity_service_schema({"some": str}))), + vol.Schema(vol.Schema(cv.make_entity_service_schema({"some": str}))), + vol.All(cv.make_entity_service_schema({"some": str})), + vol.All(vol.All(cv.make_entity_service_schema({"some": str}))), + vol.All(vol.Schema(cv.make_entity_service_schema({"some": str}))), + ): + assert cv.is_entity_service_schema(schema) is True diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py index 5ce0292c2ec..8f4ece09a17 100644 --- a/tests/helpers/test_entity_component.py +++ b/tests/helpers/test_entity_component.py @@ -23,7 +23,7 @@ from homeassistant.core import ( callback, ) from homeassistant.exceptions import HomeAssistantError, PlatformNotReady -from homeassistant.helpers import discovery +from homeassistant.helpers import config_validation as cv, discovery from homeassistant.helpers.entity_component import EntityComponent, async_update_entity from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType @@ -559,28 +559,28 @@ async def test_register_entity_service( async def test_register_entity_service_non_entity_service_schema( hass: HomeAssistant, ) -> None: - """Test attempting to register a service with an incomplete schema.""" + """Test attempting to register a service with a non entity service schema.""" component = EntityComponent(_LOGGER, DOMAIN, hass) - with pytest.raises( - HomeAssistantError, - match=( - "The schema does not include all required keys: entity_id, device_id, area_id, " - "floor_id, label_id" - ), + for schema in ( + vol.Schema({"some": str}), + vol.All(vol.Schema({"some": str})), + vol.Any(vol.Schema({"some": str})), ): - component.async_register_entity_service( - "hello", vol.Schema({"some": str}), Mock() + with pytest.raises( + HomeAssistantError, + match=("The schema is not an entity service schema"), + ): + component.async_register_entity_service("hello", schema, Mock()) + + for idx, schema in enumerate( + ( + cv.make_entity_service_schema({"some": str}), + vol.Schema(cv.make_entity_service_schema({"some": str})), + vol.All(cv.make_entity_service_schema({"some": str})), ) - - # The check currently does not recurse into vol.All or vol.Any allowing these - # non-compliant schemas to pass - component.async_register_entity_service( - "hello", vol.All(vol.Schema({"some": str})), Mock() - ) - component.async_register_entity_service( - "hello", vol.Any(vol.Schema({"some": str})), Mock() - ) + ): + component.async_register_entity_service(f"test_service_{idx}", schema, Mock()) async def test_register_entity_service_response_data(hass: HomeAssistant) -> None: diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 2cc3348626c..2b0598cfe9d 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -23,6 +23,7 @@ from homeassistant.core import ( from homeassistant.exceptions import HomeAssistantError, PlatformNotReady from homeassistant.helpers import ( area_registry as ar, + config_validation as cv, device_registry as dr, entity_platform, entity_registry as er, @@ -1812,31 +1813,32 @@ async def test_register_entity_service_none_schema( async def test_register_entity_service_non_entity_service_schema( hass: HomeAssistant, ) -> None: - """Test attempting to register a service with an incomplete schema.""" + """Test attempting to register a service with a non entity service schema.""" entity_platform = MockEntityPlatform( hass, domain="mock_integration", platform_name="mock_platform", platform=None ) - with pytest.raises( - HomeAssistantError, - match=( - "The schema does not include all required keys: entity_id, device_id, area_id, " - "floor_id, label_id" - ), + for schema in ( + vol.Schema({"some": str}), + vol.All(vol.Schema({"some": str})), + vol.Any(vol.Schema({"some": str})), + ): + with pytest.raises( + HomeAssistantError, + match="The schema is not an entity service schema", + ): + entity_platform.async_register_entity_service("hello", schema, Mock()) + + for idx, schema in enumerate( + ( + cv.make_entity_service_schema({"some": str}), + vol.Schema(cv.make_entity_service_schema({"some": str})), + vol.All(cv.make_entity_service_schema({"some": str})), + ) ): entity_platform.async_register_entity_service( - "hello", - vol.Schema({"some": str}), - Mock(), + f"test_service_{idx}", schema, Mock() ) - # The check currently does not recurse into vol.All or vol.Any allowing these - # non-compliant schemas to pass - entity_platform.async_register_entity_service( - "hello", vol.All(vol.Schema({"some": str})), Mock() - ) - entity_platform.async_register_entity_service( - "hello", vol.Any(vol.Schema({"some": str})), Mock() - ) @pytest.mark.parametrize("update_before_add", [True, False]) -- GitLab