From b3377fe5fb9f16c69da55b1a93300cd6f6a75dc7 Mon Sep 17 00:00:00 2001
From: chammp <57918757+chammp@users.noreply.github.com>
Date: Wed, 11 Sep 2024 09:36:49 +0200
Subject: [PATCH] Add condition to trigger template entities (#119689)

* Add conditions to trigger template entities

* Add tests

* Fix ruff error

* Ruff

* Apply suggestions from code review

* Deduplicate

* Tweak name used in debug message

* Add and improve type annotations of modified code

* Adjust typing

* Adjust typing

* Add typing and remove unused parameter

* Adjust typing

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Adjust return type

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

---------

Co-authored-by: Erik Montnemery <erik@montnemery.com>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
---
 .../components/automation/__init__.py         |  48 +----
 homeassistant/components/template/config.py   |   9 +-
 homeassistant/components/template/const.py    |   1 +
 .../components/template/coordinator.py        |  49 +++++-
 homeassistant/helpers/condition.py            |  41 +++++
 homeassistant/helpers/script.py               |   2 +-
 tests/components/template/test_sensor.py      | 164 ++++++++++++++++++
 7 files changed, 265 insertions(+), 49 deletions(-)

diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py
index 2081ea938ae..dacbe074e95 100644
--- a/homeassistant/components/automation/__init__.py
+++ b/homeassistant/components/automation/__init__.py
@@ -47,14 +47,7 @@ from homeassistant.core import (
     split_entity_id,
     valid_entity_id,
 )
-from homeassistant.exceptions import (
-    ConditionError,
-    ConditionErrorContainer,
-    ConditionErrorIndex,
-    HomeAssistantError,
-    ServiceNotFound,
-    TemplateError,
-)
+from homeassistant.exceptions import HomeAssistantError, ServiceNotFound, TemplateError
 from homeassistant.helpers import condition
 import homeassistant.helpers.config_validation as cv
 from homeassistant.helpers.deprecation import (
@@ -1146,38 +1139,13 @@ async def _async_process_if(
     """Process if checks."""
     if_configs = config[CONF_CONDITION]
 
-    checks: list[condition.ConditionCheckerType] = []
-    for if_config in if_configs:
-        try:
-            checks.append(await condition.async_from_config(hass, if_config))
-        except HomeAssistantError as ex:
-            LOGGER.warning("Invalid condition: %s", ex)
-            return None
-
-    def if_action(variables: Mapping[str, Any] | None = None) -> bool:
-        """AND all conditions."""
-        errors: list[ConditionErrorIndex] = []
-        for index, check in enumerate(checks):
-            try:
-                with trace_path(["condition", str(index)]):
-                    if check(hass, variables) is False:
-                        return False
-            except ConditionError as ex:
-                errors.append(
-                    ConditionErrorIndex(
-                        "condition", index=index, total=len(checks), error=ex
-                    )
-                )
-
-        if errors:
-            LOGGER.warning(
-                "Error evaluating condition in '%s':\n%s",
-                name,
-                ConditionErrorContainer("condition", errors=errors),
-            )
-            return False
-
-        return True
+    try:
+        if_action = await condition.async_conditions_from_config(
+            hass, if_configs, LOGGER, name
+        )
+    except HomeAssistantError as ex:
+        LOGGER.warning("Invalid condition: %s", ex)
+        return None
 
     result: IfAction = if_action  # type: ignore[assignment]
     result.config = if_configs
diff --git a/homeassistant/components/template/config.py b/homeassistant/components/template/config.py
index e2015743a0e..d75b111a6d0 100644
--- a/homeassistant/components/template/config.py
+++ b/homeassistant/components/template/config.py
@@ -15,6 +15,7 @@ from homeassistant.config import async_log_schema_error, config_without_domain
 from homeassistant.const import CONF_BINARY_SENSORS, CONF_SENSORS, CONF_UNIQUE_ID
 from homeassistant.core import HomeAssistant
 from homeassistant.helpers import config_validation as cv
+from homeassistant.helpers.condition import async_validate_conditions_config
 from homeassistant.helpers.trigger import async_validate_trigger_config
 from homeassistant.helpers.typing import ConfigType
 from homeassistant.setup import async_notify_setup_error
@@ -28,7 +29,7 @@ from . import (
     sensor as sensor_platform,
     weather as weather_platform,
 )
-from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN
+from .const import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN
 
 PACKAGE_MERGE_HINT = "list"
 
@@ -36,6 +37,7 @@ CONFIG_SECTION_SCHEMA = vol.Schema(
     {
         vol.Optional(CONF_UNIQUE_ID): cv.string,
         vol.Optional(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
+        vol.Optional(CONF_CONDITION): cv.CONDITIONS_SCHEMA,
         vol.Optional(CONF_ACTION): cv.SCRIPT_SCHEMA,
         vol.Optional(NUMBER_DOMAIN): vol.All(
             cv.ensure_list, [number_platform.NUMBER_SCHEMA]
@@ -83,6 +85,11 @@ async def async_validate_config(hass: HomeAssistant, config: ConfigType) -> Conf
                 cfg[CONF_TRIGGER] = await async_validate_trigger_config(
                     hass, cfg[CONF_TRIGGER]
                 )
+
+            if CONF_CONDITION in cfg:
+                cfg[CONF_CONDITION] = await async_validate_conditions_config(
+                    hass, cfg[CONF_CONDITION]
+                )
         except vol.Invalid as err:
             async_log_schema_error(err, DOMAIN, cfg, hass)
             async_notify_setup_error(hass, DOMAIN)
diff --git a/homeassistant/components/template/const.py b/homeassistant/components/template/const.py
index c320fc545b1..fc3f3c84b38 100644
--- a/homeassistant/components/template/const.py
+++ b/homeassistant/components/template/const.py
@@ -7,6 +7,7 @@ CONF_ATTRIBUTE_TEMPLATES = "attribute_templates"
 CONF_ATTRIBUTES = "attributes"
 CONF_AVAILABILITY = "availability"
 CONF_AVAILABILITY_TEMPLATE = "availability_template"
+CONF_CONDITION = "condition"
 CONF_MAX = "max"
 CONF_MIN = "min"
 CONF_OBJECT_ID = "object_id"
diff --git a/homeassistant/components/template/coordinator.py b/homeassistant/components/template/coordinator.py
index d2ce44a0ad1..50481d79d5b 100644
--- a/homeassistant/components/template/coordinator.py
+++ b/homeassistant/components/template/coordinator.py
@@ -1,16 +1,18 @@
 """Data update coordinator for trigger based template entities."""
 
-from collections.abc import Callable
+from collections.abc import Callable, Mapping
 import logging
+from typing import TYPE_CHECKING, Any
 
 from homeassistant.const import EVENT_HOMEASSISTANT_START
 from homeassistant.core import Context, CoreState, callback
-from homeassistant.helpers import discovery, trigger as trigger_helper
+from homeassistant.helpers import condition, discovery, trigger as trigger_helper
 from homeassistant.helpers.script import Script
-from homeassistant.helpers.typing import ConfigType
+from homeassistant.helpers.trace import trace_get
+from homeassistant.helpers.typing import ConfigType, TemplateVarsType
 from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
 
-from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN, PLATFORMS
+from .const import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN, PLATFORMS
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -24,6 +26,7 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator):
         """Instantiate trigger data."""
         super().__init__(hass, _LOGGER, name="Trigger Update Coordinator")
         self.config = config
+        self._cond_func: Callable[[Mapping[str, Any] | None], bool] | None = None
         self._unsub_start: Callable[[], None] | None = None
         self._unsub_trigger: Callable[[], None] | None = None
         self._script: Script | None = None
@@ -73,6 +76,11 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator):
                 DOMAIN,
             )
 
+        if CONF_CONDITION in self.config:
+            self._cond_func = await condition.async_conditions_from_config(
+                self.hass, self.config[CONF_CONDITION], _LOGGER, "template entity"
+            )
+
         if start_event is not None:
             self._unsub_start = None
 
@@ -91,16 +99,43 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator):
             start_event is not None,
         )
 
-    async def _handle_triggered_with_script(self, run_variables, context=None):
+    async def _handle_triggered_with_script(
+        self, run_variables: TemplateVarsType, context: Context | None = None
+    ) -> None:
+        if not self._check_condition(run_variables):
+            return
         # Create a context referring to the trigger context.
         trigger_context_id = None if context is None else context.id
         script_context = Context(parent_id=trigger_context_id)
+        if TYPE_CHECKING:
+            # This method is only called if there's a script
+            assert self._script is not None
         if script_result := await self._script.async_run(run_variables, script_context):
             run_variables = script_result.variables
-        self._handle_triggered(run_variables, context)
+        self._execute_update(run_variables, context)
+
+    async def _handle_triggered(
+        self, run_variables: TemplateVarsType, context: Context | None = None
+    ) -> None:
+        if not self._check_condition(run_variables):
+            return
+        self._execute_update(run_variables, context)
+
+    def _check_condition(self, run_variables: TemplateVarsType) -> bool:
+        if not self._cond_func:
+            return True
+        condition_result = self._cond_func(run_variables)
+        if condition_result is False:
+            _LOGGER.debug(
+                "Conditions not met, aborting template trigger update. Condition summary: %s",
+                trace_get(clear=False),
+            )
+        return condition_result
 
     @callback
-    def _handle_triggered(self, run_variables, context=None):
+    def _execute_update(
+        self, run_variables: TemplateVarsType, context: Context | None = None
+    ) -> None:
         self.async_set_updated_data(
             {"run_variables": run_variables, "context": context}
         )
diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py
index 629cdeef942..86965f86d40 100644
--- a/homeassistant/helpers/condition.py
+++ b/homeassistant/helpers/condition.py
@@ -8,6 +8,7 @@ from collections.abc import Callable, Container, Generator
 from contextlib import contextmanager
 from datetime import datetime, time as dt_time, timedelta
 import functools as ft
+import logging
 import re
 import sys
 from typing import Any, Protocol, cast
@@ -1064,6 +1065,46 @@ async def async_validate_conditions_config(
     return [await async_validate_condition_config(hass, cond) for cond in conditions]
 
 
+async def async_conditions_from_config(
+    hass: HomeAssistant,
+    condition_configs: list[ConfigType],
+    logger: logging.Logger,
+    name: str,
+) -> Callable[[TemplateVarsType], bool]:
+    """AND all conditions."""
+    checks: list[ConditionCheckerType] = [
+        await async_from_config(hass, condition_config)
+        for condition_config in condition_configs
+    ]
+
+    def check_conditions(variables: TemplateVarsType = None) -> bool:
+        """AND all conditions."""
+        errors: list[ConditionErrorIndex] = []
+        for index, check in enumerate(checks):
+            try:
+                with trace_path(["condition", str(index)]):
+                    if check(hass, variables) is False:
+                        return False
+            except ConditionError as ex:
+                errors.append(
+                    ConditionErrorIndex(
+                        "condition", index=index, total=len(checks), error=ex
+                    )
+                )
+
+        if errors:
+            logger.warning(
+                "Error evaluating condition in '%s':\n%s",
+                name,
+                ConditionErrorContainer("condition", errors=errors),
+            )
+            return False
+
+        return True
+
+    return check_conditions
+
+
 @callback
 def async_extract_entities(config: ConfigType | Template) -> set[str]:
     """Extract entities from a condition."""
diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py
index 26a9b6e069e..0b5c0b99c35 100644
--- a/homeassistant/helpers/script.py
+++ b/homeassistant/helpers/script.py
@@ -1349,7 +1349,7 @@ async def _async_stop_scripts_at_shutdown(hass: HomeAssistant, event: Event) ->
         )
 
 
-type _VarsType = dict[str, Any] | MappingProxyType[str, Any]
+type _VarsType = dict[str, Any] | Mapping[str, Any] | MappingProxyType[str, Any]
 
 
 def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
diff --git a/tests/components/template/test_sensor.py b/tests/components/template/test_sensor.py
index fb352ebcb8c..e5e6eba1068 100644
--- a/tests/components/template/test_sensor.py
+++ b/tests/components/template/test_sensor.py
@@ -1207,6 +1207,124 @@ async def test_trigger_entity(
     assert state.context is context
 
 
+@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)])
+@pytest.mark.parametrize(
+    "config",
+    [
+        {
+            "template": [
+                {
+                    "unique_id": "listening-test-event",
+                    "trigger": {"platform": "event", "event_type": "test_event"},
+                    "condition": [
+                        {
+                            "condition": "template",
+                            "value_template": "{{ trigger.event.data.beer >= 42 }}",
+                        }
+                    ],
+                    "sensor": [
+                        {
+                            "name": "Enough Name",
+                            "unique_id": "enough-id",
+                            "state": "You had enough Beer.",
+                        }
+                    ],
+                },
+            ],
+        },
+    ],
+)
+async def test_trigger_conditional_entity(hass: HomeAssistant, start_ha) -> None:
+    """Test conditional trigger entity works."""
+    state = hass.states.get("sensor.enough_name")
+    assert state is not None
+    assert state.state == STATE_UNKNOWN
+
+    hass.bus.async_fire("test_event", {"beer": 2})
+    await hass.async_block_till_done()
+
+    state = hass.states.get("sensor.enough_name")
+    assert state.state == STATE_UNKNOWN
+
+    hass.bus.async_fire("test_event", {"beer": 42})
+    await hass.async_block_till_done()
+
+    state = hass.states.get("sensor.enough_name")
+    assert state.state == "You had enough Beer."
+
+
+@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)])
+@pytest.mark.parametrize(
+    "config",
+    [
+        {
+            "template": [
+                {
+                    "unique_id": "listening-test-event",
+                    "trigger": {"platform": "event", "event_type": "test_event"},
+                    "condition": [
+                        {
+                            "condition": "template",
+                            "value_template": "{{ trigger.event.data.beer / 0 == 'narf' }}",
+                        }
+                    ],
+                    "sensor": [
+                        {
+                            "name": "Enough Name",
+                            "unique_id": "enough-id",
+                            "state": "You had enough Beer.",
+                        }
+                    ],
+                },
+            ],
+        },
+    ],
+)
+async def test_trigger_conditional_entity_evaluation_error(
+    hass: HomeAssistant, caplog: pytest.LogCaptureFixture, start_ha
+) -> None:
+    """Test trigger entity is not updated when condition evaluation fails."""
+    hass.bus.async_fire("test_event", {"beer": 1})
+    await hass.async_block_till_done()
+
+    state = hass.states.get("sensor.enough_name")
+    assert state is not None
+    assert state.state == STATE_UNKNOWN
+
+    assert "Error evaluating condition in 'template entity'" in caplog.text
+
+
+@pytest.mark.parametrize(("count", "domain"), [(0, template.DOMAIN)])
+@pytest.mark.parametrize(
+    "config",
+    [
+        {
+            "template": [
+                {
+                    "unique_id": "listening-test-event",
+                    "trigger": {"platform": "event", "event_type": "test_event"},
+                    "condition": [
+                        {"condition": "template", "value_template": "{{ invalid"}
+                    ],
+                    "sensor": [
+                        {
+                            "name": "Will Not Exist Name",
+                            "state": "Unimportant",
+                        }
+                    ],
+                },
+            ],
+        },
+    ],
+)
+async def test_trigger_conditional_entity_invalid_condition(
+    hass: HomeAssistant, start_ha
+) -> None:
+    """Test trigger entity is not created when condition is invalid."""
+    state = hass.states.get("sensor.will_not_exist_name")
+    assert state is None
+
+
 @pytest.mark.parametrize(("count", "domain"), [(1, "template")])
 @pytest.mark.parametrize(
     "config",
@@ -1903,6 +2021,52 @@ async def test_trigger_action(
     assert events[0].context.parent_id == context.id
 
 
+@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)])
+@pytest.mark.parametrize(
+    "config",
+    [
+        {
+            "template": [
+                {
+                    "unique_id": "listening-test-event",
+                    "trigger": {"platform": "event", "event_type": "test_event"},
+                    "condition": [
+                        {
+                            "condition": "template",
+                            "value_template": "{{ trigger.event.data.beer >= 42 }}",
+                        }
+                    ],
+                    "action": [
+                        {"event": "test_event_by_action"},
+                    ],
+                    "sensor": [
+                        {
+                            "name": "Not That Important",
+                            "state": "Really not.",
+                        }
+                    ],
+                },
+            ],
+        },
+    ],
+)
+async def test_trigger_conditional_action(hass: HomeAssistant, start_ha) -> None:
+    """Test conditional trigger entity with an action works."""
+
+    event = "test_event_by_action"
+    events = async_capture_events(hass, event)
+
+    hass.bus.async_fire("test_event", {"beer": 1})
+    await hass.async_block_till_done()
+
+    assert len(events) == 0
+
+    hass.bus.async_fire("test_event", {"beer": 42})
+    await hass.async_block_till_done()
+
+    assert len(events) == 1
+
+
 async def test_device_id(
     hass: HomeAssistant,
     device_registry: dr.DeviceRegistry,
-- 
GitLab