From cd57b764ce9b0f14568ff1c3da8893450b832e70 Mon Sep 17 00:00:00 2001 From: Eugenio Panadero <eugenio.panadero@gmail.com> Date: Tue, 24 Mar 2020 00:05:21 +0100 Subject: [PATCH] Fix state_automation_listener when new state is None (#32985) * Fix state_automation_listener when new state is None (fix #32984) * Listen to EVENT_STATE_CHANGED instead of using async_track_state_change and use the event context on automation trigger. * Share `process_state_match` with helpers/event * Add test for state change automation on entity removal --- homeassistant/components/automation/state.py | 47 +++++++++++++++----- homeassistant/helpers/event.py | 6 +-- tests/components/automation/test_state.py | 22 +++++++++ 3 files changed, 60 insertions(+), 15 deletions(-) diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py index fc3fff47514..29aea64c9c5 100644 --- a/homeassistant/components/automation/state.py +++ b/homeassistant/components/automation/state.py @@ -6,10 +6,14 @@ from typing import Dict import voluptuous as vol from homeassistant import exceptions -from homeassistant.const import CONF_FOR, CONF_PLATFORM, MATCH_ALL +from homeassistant.const import CONF_FOR, CONF_PLATFORM, EVENT_STATE_CHANGED, MATCH_ALL from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers import config_validation as cv, template -from homeassistant.helpers.event import async_track_same_state, async_track_state_change +from homeassistant.helpers.event import ( + Event, + async_track_same_state, + process_state_match, +) # mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: no-check-untyped-defs @@ -56,10 +60,30 @@ async def async_attach_trigger( match_all = from_state == MATCH_ALL and to_state == MATCH_ALL unsub_track_same = {} period: Dict[str, timedelta] = {} + match_from_state = process_state_match(from_state) + match_to_state = process_state_match(to_state) @callback - def state_automation_listener(entity, from_s, to_s): + def state_automation_listener(event: Event): """Listen for state changes and calls action.""" + entity: str = event.data["entity_id"] + if entity not in entity_id: + return + + from_s = event.data.get("old_state") + to_s = event.data.get("new_state") + + if ( + (from_s is not None and not match_from_state(from_s.state)) + or (to_s is not None and not match_to_state(to_s.state)) + or ( + not match_all + and from_s is not None + and to_s is not None + and from_s.state == to_s.state + ) + ): + return @callback def call_action(): @@ -75,7 +99,7 @@ async def async_attach_trigger( "for": time_delta if not time_delta else period[entity], } }, - context=to_s.context, + context=event.context, ) ) @@ -120,17 +144,16 @@ async def async_attach_trigger( ) return + def _check_same_state(_, _2, new_st): + if new_st is None: + return False + return new_st.state == to_s.state + unsub_track_same[entity] = async_track_same_state( - hass, - period[entity], - call_action, - lambda _, _2, to_state: to_state.state == to_s.state, - entity_ids=entity, + hass, period[entity], call_action, _check_same_state, entity_ids=entity, ) - unsub = async_track_state_change( - hass, entity_id, state_automation_listener, from_state, to_state - ) + unsub = hass.bus.async_listen(EVENT_STATE_CHANGED, state_automation_listener) @callback def async_remove(): diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 74faca6a1d2..8a4b4bc2b76 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -67,8 +67,8 @@ def async_track_state_change( Must be run within the event loop. """ - match_from_state = _process_state_match(from_state) - match_to_state = _process_state_match(to_state) + match_from_state = process_state_match(from_state) + match_to_state = process_state_match(to_state) # Ensure it is a lowercase list with entity ids we want to match on if entity_ids == MATCH_ALL: @@ -473,7 +473,7 @@ def async_track_time_change( track_time_change = threaded_listener_factory(async_track_time_change) -def _process_state_match( +def process_state_match( parameter: Union[None, str, Iterable[str]] ) -> Callable[[str], bool]: """Convert parameter to function that matches input against parameter.""" diff --git a/tests/components/automation/test_state.py b/tests/components/automation/test_state.py index 9d4fa9a1100..173af8158a4 100644 --- a/tests/components/automation/test_state.py +++ b/tests/components/automation/test_state.py @@ -519,6 +519,28 @@ async def test_if_fires_on_entity_change_with_for(hass, calls): assert 1 == len(calls) +async def test_if_fires_on_entity_removal(hass, calls): + """Test for firing on entity removal, when new_state is None.""" + hass.states.async_set("test.entity", "hello") + await hass.async_block_till_done() + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": {"platform": "state", "entity_id": "test.entity"}, + "action": {"service": "test.automation"}, + } + }, + ) + await hass.async_block_till_done() + + assert hass.states.async_remove("test.entity") + await hass.async_block_till_done() + assert 1 == len(calls) + + async def test_if_fires_on_for_condition(hass, calls): """Test for firing if condition is on.""" point1 = dt_util.utcnow() -- GitLab