From b3377fe5fb9f16c69da55b1a93300cd6f6a75dc7 Mon Sep 17 00:00:00 2001 From: chammp <57918757+chammp@users.noreply.github.com> Date: Wed, 11 Sep 2024 09:36:49 +0200 Subject: [PATCH] Add condition to trigger template entities (#119689) * Add conditions to trigger template entities * Add tests * Fix ruff error * Ruff * Apply suggestions from code review * Deduplicate * Tweak name used in debug message * Add and improve type annotations of modified code * Adjust typing * Adjust typing * Add typing and remove unused parameter * Adjust typing Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Adjust return type Co-authored-by: Martin Hjelmare <marhje52@gmail.com> --------- Co-authored-by: Erik Montnemery <erik@montnemery.com> Co-authored-by: Martin Hjelmare <marhje52@gmail.com> --- .../components/automation/__init__.py | 48 +---- homeassistant/components/template/config.py | 9 +- homeassistant/components/template/const.py | 1 + .../components/template/coordinator.py | 49 +++++- homeassistant/helpers/condition.py | 41 +++++ homeassistant/helpers/script.py | 2 +- tests/components/template/test_sensor.py | 164 ++++++++++++++++++ 7 files changed, 265 insertions(+), 49 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 2081ea938ae..dacbe074e95 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -47,14 +47,7 @@ from homeassistant.core import ( split_entity_id, valid_entity_id, ) -from homeassistant.exceptions import ( - ConditionError, - ConditionErrorContainer, - ConditionErrorIndex, - HomeAssistantError, - ServiceNotFound, - TemplateError, -) +from homeassistant.exceptions import HomeAssistantError, ServiceNotFound, TemplateError from homeassistant.helpers import condition import homeassistant.helpers.config_validation as cv from homeassistant.helpers.deprecation import ( @@ -1146,38 +1139,13 @@ async def _async_process_if( """Process if checks.""" if_configs = config[CONF_CONDITION] - checks: list[condition.ConditionCheckerType] = [] - for if_config in if_configs: - try: - checks.append(await condition.async_from_config(hass, if_config)) - except HomeAssistantError as ex: - LOGGER.warning("Invalid condition: %s", ex) - return None - - def if_action(variables: Mapping[str, Any] | None = None) -> bool: - """AND all conditions.""" - errors: list[ConditionErrorIndex] = [] - for index, check in enumerate(checks): - try: - with trace_path(["condition", str(index)]): - if check(hass, variables) is False: - return False - except ConditionError as ex: - errors.append( - ConditionErrorIndex( - "condition", index=index, total=len(checks), error=ex - ) - ) - - if errors: - LOGGER.warning( - "Error evaluating condition in '%s':\n%s", - name, - ConditionErrorContainer("condition", errors=errors), - ) - return False - - return True + try: + if_action = await condition.async_conditions_from_config( + hass, if_configs, LOGGER, name + ) + except HomeAssistantError as ex: + LOGGER.warning("Invalid condition: %s", ex) + return None result: IfAction = if_action # type: ignore[assignment] result.config = if_configs diff --git a/homeassistant/components/template/config.py b/homeassistant/components/template/config.py index e2015743a0e..d75b111a6d0 100644 --- a/homeassistant/components/template/config.py +++ b/homeassistant/components/template/config.py @@ -15,6 +15,7 @@ from homeassistant.config import async_log_schema_error, config_without_domain from homeassistant.const import CONF_BINARY_SENSORS, CONF_SENSORS, CONF_UNIQUE_ID from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.condition import async_validate_conditions_config from homeassistant.helpers.trigger import async_validate_trigger_config from homeassistant.helpers.typing import ConfigType from homeassistant.setup import async_notify_setup_error @@ -28,7 +29,7 @@ from . import ( sensor as sensor_platform, weather as weather_platform, ) -from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN +from .const import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN PACKAGE_MERGE_HINT = "list" @@ -36,6 +37,7 @@ CONFIG_SECTION_SCHEMA = vol.Schema( { vol.Optional(CONF_UNIQUE_ID): cv.string, vol.Optional(CONF_TRIGGER): cv.TRIGGER_SCHEMA, + vol.Optional(CONF_CONDITION): cv.CONDITIONS_SCHEMA, vol.Optional(CONF_ACTION): cv.SCRIPT_SCHEMA, vol.Optional(NUMBER_DOMAIN): vol.All( cv.ensure_list, [number_platform.NUMBER_SCHEMA] @@ -83,6 +85,11 @@ async def async_validate_config(hass: HomeAssistant, config: ConfigType) -> Conf cfg[CONF_TRIGGER] = await async_validate_trigger_config( hass, cfg[CONF_TRIGGER] ) + + if CONF_CONDITION in cfg: + cfg[CONF_CONDITION] = await async_validate_conditions_config( + hass, cfg[CONF_CONDITION] + ) except vol.Invalid as err: async_log_schema_error(err, DOMAIN, cfg, hass) async_notify_setup_error(hass, DOMAIN) diff --git a/homeassistant/components/template/const.py b/homeassistant/components/template/const.py index c320fc545b1..fc3f3c84b38 100644 --- a/homeassistant/components/template/const.py +++ b/homeassistant/components/template/const.py @@ -7,6 +7,7 @@ CONF_ATTRIBUTE_TEMPLATES = "attribute_templates" CONF_ATTRIBUTES = "attributes" CONF_AVAILABILITY = "availability" CONF_AVAILABILITY_TEMPLATE = "availability_template" +CONF_CONDITION = "condition" CONF_MAX = "max" CONF_MIN = "min" CONF_OBJECT_ID = "object_id" diff --git a/homeassistant/components/template/coordinator.py b/homeassistant/components/template/coordinator.py index d2ce44a0ad1..50481d79d5b 100644 --- a/homeassistant/components/template/coordinator.py +++ b/homeassistant/components/template/coordinator.py @@ -1,16 +1,18 @@ """Data update coordinator for trigger based template entities.""" -from collections.abc import Callable +from collections.abc import Callable, Mapping import logging +from typing import TYPE_CHECKING, Any from homeassistant.const import EVENT_HOMEASSISTANT_START from homeassistant.core import Context, CoreState, callback -from homeassistant.helpers import discovery, trigger as trigger_helper +from homeassistant.helpers import condition, discovery, trigger as trigger_helper from homeassistant.helpers.script import Script -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.trace import trace_get +from homeassistant.helpers.typing import ConfigType, TemplateVarsType from homeassistant.helpers.update_coordinator import DataUpdateCoordinator -from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN, PLATFORMS +from .const import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN, PLATFORMS _LOGGER = logging.getLogger(__name__) @@ -24,6 +26,7 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator): """Instantiate trigger data.""" super().__init__(hass, _LOGGER, name="Trigger Update Coordinator") self.config = config + self._cond_func: Callable[[Mapping[str, Any] | None], bool] | None = None self._unsub_start: Callable[[], None] | None = None self._unsub_trigger: Callable[[], None] | None = None self._script: Script | None = None @@ -73,6 +76,11 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator): DOMAIN, ) + if CONF_CONDITION in self.config: + self._cond_func = await condition.async_conditions_from_config( + self.hass, self.config[CONF_CONDITION], _LOGGER, "template entity" + ) + if start_event is not None: self._unsub_start = None @@ -91,16 +99,43 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator): start_event is not None, ) - async def _handle_triggered_with_script(self, run_variables, context=None): + async def _handle_triggered_with_script( + self, run_variables: TemplateVarsType, context: Context | None = None + ) -> None: + if not self._check_condition(run_variables): + return # Create a context referring to the trigger context. trigger_context_id = None if context is None else context.id script_context = Context(parent_id=trigger_context_id) + if TYPE_CHECKING: + # This method is only called if there's a script + assert self._script is not None if script_result := await self._script.async_run(run_variables, script_context): run_variables = script_result.variables - self._handle_triggered(run_variables, context) + self._execute_update(run_variables, context) + + async def _handle_triggered( + self, run_variables: TemplateVarsType, context: Context | None = None + ) -> None: + if not self._check_condition(run_variables): + return + self._execute_update(run_variables, context) + + def _check_condition(self, run_variables: TemplateVarsType) -> bool: + if not self._cond_func: + return True + condition_result = self._cond_func(run_variables) + if condition_result is False: + _LOGGER.debug( + "Conditions not met, aborting template trigger update. Condition summary: %s", + trace_get(clear=False), + ) + return condition_result @callback - def _handle_triggered(self, run_variables, context=None): + def _execute_update( + self, run_variables: TemplateVarsType, context: Context | None = None + ) -> None: self.async_set_updated_data( {"run_variables": run_variables, "context": context} ) diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index 629cdeef942..86965f86d40 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -8,6 +8,7 @@ from collections.abc import Callable, Container, Generator from contextlib import contextmanager from datetime import datetime, time as dt_time, timedelta import functools as ft +import logging import re import sys from typing import Any, Protocol, cast @@ -1064,6 +1065,46 @@ async def async_validate_conditions_config( return [await async_validate_condition_config(hass, cond) for cond in conditions] +async def async_conditions_from_config( + hass: HomeAssistant, + condition_configs: list[ConfigType], + logger: logging.Logger, + name: str, +) -> Callable[[TemplateVarsType], bool]: + """AND all conditions.""" + checks: list[ConditionCheckerType] = [ + await async_from_config(hass, condition_config) + for condition_config in condition_configs + ] + + def check_conditions(variables: TemplateVarsType = None) -> bool: + """AND all conditions.""" + errors: list[ConditionErrorIndex] = [] + for index, check in enumerate(checks): + try: + with trace_path(["condition", str(index)]): + if check(hass, variables) is False: + return False + except ConditionError as ex: + errors.append( + ConditionErrorIndex( + "condition", index=index, total=len(checks), error=ex + ) + ) + + if errors: + logger.warning( + "Error evaluating condition in '%s':\n%s", + name, + ConditionErrorContainer("condition", errors=errors), + ) + return False + + return True + + return check_conditions + + @callback def async_extract_entities(config: ConfigType | Template) -> set[str]: """Extract entities from a condition.""" diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 26a9b6e069e..0b5c0b99c35 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -1349,7 +1349,7 @@ async def _async_stop_scripts_at_shutdown(hass: HomeAssistant, event: Event) -> ) -type _VarsType = dict[str, Any] | MappingProxyType[str, Any] +type _VarsType = dict[str, Any] | Mapping[str, Any] | MappingProxyType[str, Any] def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None: diff --git a/tests/components/template/test_sensor.py b/tests/components/template/test_sensor.py index fb352ebcb8c..e5e6eba1068 100644 --- a/tests/components/template/test_sensor.py +++ b/tests/components/template/test_sensor.py @@ -1207,6 +1207,124 @@ async def test_trigger_entity( assert state.context is context +@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)]) +@pytest.mark.parametrize( + "config", + [ + { + "template": [ + { + "unique_id": "listening-test-event", + "trigger": {"platform": "event", "event_type": "test_event"}, + "condition": [ + { + "condition": "template", + "value_template": "{{ trigger.event.data.beer >= 42 }}", + } + ], + "sensor": [ + { + "name": "Enough Name", + "unique_id": "enough-id", + "state": "You had enough Beer.", + } + ], + }, + ], + }, + ], +) +async def test_trigger_conditional_entity(hass: HomeAssistant, start_ha) -> None: + """Test conditional trigger entity works.""" + state = hass.states.get("sensor.enough_name") + assert state is not None + assert state.state == STATE_UNKNOWN + + hass.bus.async_fire("test_event", {"beer": 2}) + await hass.async_block_till_done() + + state = hass.states.get("sensor.enough_name") + assert state.state == STATE_UNKNOWN + + hass.bus.async_fire("test_event", {"beer": 42}) + await hass.async_block_till_done() + + state = hass.states.get("sensor.enough_name") + assert state.state == "You had enough Beer." + + +@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)]) +@pytest.mark.parametrize( + "config", + [ + { + "template": [ + { + "unique_id": "listening-test-event", + "trigger": {"platform": "event", "event_type": "test_event"}, + "condition": [ + { + "condition": "template", + "value_template": "{{ trigger.event.data.beer / 0 == 'narf' }}", + } + ], + "sensor": [ + { + "name": "Enough Name", + "unique_id": "enough-id", + "state": "You had enough Beer.", + } + ], + }, + ], + }, + ], +) +async def test_trigger_conditional_entity_evaluation_error( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture, start_ha +) -> None: + """Test trigger entity is not updated when condition evaluation fails.""" + hass.bus.async_fire("test_event", {"beer": 1}) + await hass.async_block_till_done() + + state = hass.states.get("sensor.enough_name") + assert state is not None + assert state.state == STATE_UNKNOWN + + assert "Error evaluating condition in 'template entity'" in caplog.text + + +@pytest.mark.parametrize(("count", "domain"), [(0, template.DOMAIN)]) +@pytest.mark.parametrize( + "config", + [ + { + "template": [ + { + "unique_id": "listening-test-event", + "trigger": {"platform": "event", "event_type": "test_event"}, + "condition": [ + {"condition": "template", "value_template": "{{ invalid"} + ], + "sensor": [ + { + "name": "Will Not Exist Name", + "state": "Unimportant", + } + ], + }, + ], + }, + ], +) +async def test_trigger_conditional_entity_invalid_condition( + hass: HomeAssistant, start_ha +) -> None: + """Test trigger entity is not created when condition is invalid.""" + state = hass.states.get("sensor.will_not_exist_name") + assert state is None + + @pytest.mark.parametrize(("count", "domain"), [(1, "template")]) @pytest.mark.parametrize( "config", @@ -1903,6 +2021,52 @@ async def test_trigger_action( assert events[0].context.parent_id == context.id +@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)]) +@pytest.mark.parametrize( + "config", + [ + { + "template": [ + { + "unique_id": "listening-test-event", + "trigger": {"platform": "event", "event_type": "test_event"}, + "condition": [ + { + "condition": "template", + "value_template": "{{ trigger.event.data.beer >= 42 }}", + } + ], + "action": [ + {"event": "test_event_by_action"}, + ], + "sensor": [ + { + "name": "Not That Important", + "state": "Really not.", + } + ], + }, + ], + }, + ], +) +async def test_trigger_conditional_action(hass: HomeAssistant, start_ha) -> None: + """Test conditional trigger entity with an action works.""" + + event = "test_event_by_action" + events = async_capture_events(hass, event) + + hass.bus.async_fire("test_event", {"beer": 1}) + await hass.async_block_till_done() + + assert len(events) == 0 + + hass.bus.async_fire("test_event", {"beer": 42}) + await hass.async_block_till_done() + + assert len(events) == 1 + + async def test_device_id( hass: HomeAssistant, device_registry: dr.DeviceRegistry, -- GitLab