diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index d7a5d5ae8a183e6798720e47d7129915885d7e44..2904efb75e9f82bceefbc66fb574b4598763908d 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -1305,9 +1305,28 @@ TARGET_SERVICE_FIELDS = { _HAS_ENTITY_SERVICE_FIELD = has_at_least_one_key(*ENTITY_SERVICE_FIELDS) +def is_entity_service_schema(validator: VolSchemaType) -> bool: + """Check if the passed validator is an entity schema validator. + + The validator must be either of: + - A validator returned by cv._make_entity_service_schema + - A validator returned by cv._make_entity_service_schema, wrapped in a vol.Schema + - A validator returned by cv._make_entity_service_schema, wrapped in a vol.All + Nesting is allowed. + """ + if hasattr(validator, "_entity_service_schema"): + return True + if isinstance(validator, (vol.All)): + return any(is_entity_service_schema(val) for val in validator.validators) + if isinstance(validator, (vol.Schema)): + return is_entity_service_schema(validator.schema) + + return False + + def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType: """Create an entity service schema.""" - return vol.All( + validator = vol.All( vol.Schema( { # The frontend stores data here. Don't use in core. @@ -1319,6 +1338,8 @@ def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType: ), _HAS_ENTITY_SERVICE_FIELD, ) + setattr(validator, "_entity_service_schema", True) + return validator BASE_ENTITY_SCHEMA = _make_entity_service_schema({}, vol.PREVENT_EXTRA) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 0551b5289c548d5f11f0782a2174b67586bd1a92..573073f380916a1d9e583f7661b0171a5fe0350c 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -1267,17 +1267,8 @@ def async_register_entity_service( # Do a sanity check to check this is a valid entity service schema, # the check could be extended to require All/Any to have sub schema(s) # with all entity service fields - elif ( - # Don't check All/Any - not isinstance(schema, (vol.All, vol.Any)) - # Don't check All/Any wrapped in schema - and not isinstance(schema.schema, (vol.All, vol.Any)) - and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS) - ): - raise HomeAssistantError( - "The schema does not include all required keys: " - f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}" - ) + elif not cv.is_entity_service_schema(schema): + raise HomeAssistantError("The schema is not an entity service schema") service_func: str | HassJob[..., Any] service_func = func if isinstance(func, str) else HassJob(func) diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 973f504df08edf8a462b04f64dff9bda41190bde..57c712e2f10d05ea0a9cf675fd0976ae6dc1eda3 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -1805,3 +1805,27 @@ async def test_async_validate(hass: HomeAssistant, tmpdir: py.path.local) -> Non "string": [hass.loop_thread_id], } validator_calls = {} + + +async def test_is_entity_service_schema( + hass: HomeAssistant, +) -> None: + """Test cv.is_entity_service_schema.""" + for schema in ( + vol.Schema({"some": str}), + vol.All(vol.Schema({"some": str})), + vol.Any(vol.Schema({"some": str})), + vol.Any(cv.make_entity_service_schema({"some": str})), + ): + assert cv.is_entity_service_schema(schema) is False + + for schema in ( + cv.make_entity_service_schema({"some": str}), + vol.Schema(cv.make_entity_service_schema({"some": str})), + vol.Schema(vol.All(cv.make_entity_service_schema({"some": str}))), + vol.Schema(vol.Schema(cv.make_entity_service_schema({"some": str}))), + vol.All(cv.make_entity_service_schema({"some": str})), + vol.All(vol.All(cv.make_entity_service_schema({"some": str}))), + vol.All(vol.Schema(cv.make_entity_service_schema({"some": str}))), + ): + assert cv.is_entity_service_schema(schema) is True diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py index 5ce0292c2ec5fa1c128fb9afdc4f5aff51ef2bf6..8f4ece09a1734cd014453b1cfef8f0d146e48c02 100644 --- a/tests/helpers/test_entity_component.py +++ b/tests/helpers/test_entity_component.py @@ -23,7 +23,7 @@ from homeassistant.core import ( callback, ) from homeassistant.exceptions import HomeAssistantError, PlatformNotReady -from homeassistant.helpers import discovery +from homeassistant.helpers import config_validation as cv, discovery from homeassistant.helpers.entity_component import EntityComponent, async_update_entity from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType @@ -559,28 +559,28 @@ async def test_register_entity_service( async def test_register_entity_service_non_entity_service_schema( hass: HomeAssistant, ) -> None: - """Test attempting to register a service with an incomplete schema.""" + """Test attempting to register a service with a non entity service schema.""" component = EntityComponent(_LOGGER, DOMAIN, hass) - with pytest.raises( - HomeAssistantError, - match=( - "The schema does not include all required keys: entity_id, device_id, area_id, " - "floor_id, label_id" - ), + for schema in ( + vol.Schema({"some": str}), + vol.All(vol.Schema({"some": str})), + vol.Any(vol.Schema({"some": str})), ): - component.async_register_entity_service( - "hello", vol.Schema({"some": str}), Mock() + with pytest.raises( + HomeAssistantError, + match=("The schema is not an entity service schema"), + ): + component.async_register_entity_service("hello", schema, Mock()) + + for idx, schema in enumerate( + ( + cv.make_entity_service_schema({"some": str}), + vol.Schema(cv.make_entity_service_schema({"some": str})), + vol.All(cv.make_entity_service_schema({"some": str})), ) - - # The check currently does not recurse into vol.All or vol.Any allowing these - # non-compliant schemas to pass - component.async_register_entity_service( - "hello", vol.All(vol.Schema({"some": str})), Mock() - ) - component.async_register_entity_service( - "hello", vol.Any(vol.Schema({"some": str})), Mock() - ) + ): + component.async_register_entity_service(f"test_service_{idx}", schema, Mock()) async def test_register_entity_service_response_data(hass: HomeAssistant) -> None: diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 2cc3348626cbd0547e921f7baf388341dbd50ab3..2b0598cfe9d8568c6b40628126f3aebf1e21b12d 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -23,6 +23,7 @@ from homeassistant.core import ( from homeassistant.exceptions import HomeAssistantError, PlatformNotReady from homeassistant.helpers import ( area_registry as ar, + config_validation as cv, device_registry as dr, entity_platform, entity_registry as er, @@ -1812,31 +1813,32 @@ async def test_register_entity_service_none_schema( async def test_register_entity_service_non_entity_service_schema( hass: HomeAssistant, ) -> None: - """Test attempting to register a service with an incomplete schema.""" + """Test attempting to register a service with a non entity service schema.""" entity_platform = MockEntityPlatform( hass, domain="mock_integration", platform_name="mock_platform", platform=None ) - with pytest.raises( - HomeAssistantError, - match=( - "The schema does not include all required keys: entity_id, device_id, area_id, " - "floor_id, label_id" - ), + for schema in ( + vol.Schema({"some": str}), + vol.All(vol.Schema({"some": str})), + vol.Any(vol.Schema({"some": str})), + ): + with pytest.raises( + HomeAssistantError, + match="The schema is not an entity service schema", + ): + entity_platform.async_register_entity_service("hello", schema, Mock()) + + for idx, schema in enumerate( + ( + cv.make_entity_service_schema({"some": str}), + vol.Schema(cv.make_entity_service_schema({"some": str})), + vol.All(cv.make_entity_service_schema({"some": str})), + ) ): entity_platform.async_register_entity_service( - "hello", - vol.Schema({"some": str}), - Mock(), + f"test_service_{idx}", schema, Mock() ) - # The check currently does not recurse into vol.All or vol.Any allowing these - # non-compliant schemas to pass - entity_platform.async_register_entity_service( - "hello", vol.All(vol.Schema({"some": str})), Mock() - ) - entity_platform.async_register_entity_service( - "hello", vol.Any(vol.Schema({"some": str})), Mock() - ) @pytest.mark.parametrize("update_before_add", [True, False])