diff --git a/homeassistant/const.py b/homeassistant/const.py index d27b5e3a1b79676f11043257ce1cb0f9c437824d..b9b71734241b7ca258a07e6927805e6d9fe5871b 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -225,6 +225,9 @@ ATTR_ID = 'id' # Name ATTR_NAME = 'name' +# Data for a SERVICE_EXECUTED event +ATTR_SERVICE_CALL_ID = 'service_call_id' + # Contains one string or a list of strings, each being an entity id ATTR_ENTITY_ID = 'entity_id' diff --git a/homeassistant/core.py b/homeassistant/core.py index fdbbe49ea05d69c9e88260c487bc2054781cbae3..39ee20cb1a8c824b2b65b1664f3074eefed84fb8 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -29,7 +29,7 @@ from voluptuous.humanize import humanize_error from homeassistant.const import ( ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE, - ATTR_SERVICE_DATA, EVENT_CALL_SERVICE, + ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, EVENT_CALL_SERVICE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE, @@ -1042,10 +1042,12 @@ class ServiceRegistry: This method is a coroutine. """ context = context or Context() + call_id = uuid.uuid4().hex event_data = { ATTR_DOMAIN: domain.lower(), ATTR_SERVICE: service.lower(), ATTR_SERVICE_DATA: service_data, + ATTR_SERVICE_CALL_ID: call_id, } if not blocking: @@ -1058,8 +1060,9 @@ class ServiceRegistry: @callback def service_executed(event: Event) -> None: """Handle an executed service.""" - if event.context == context: + if event.data[ATTR_SERVICE_CALL_ID] == call_id: fut.set_result(True) + unsub() unsub = self._hass.bus.async_listen( EVENT_SERVICE_EXECUTED, service_executed) @@ -1069,7 +1072,8 @@ class ServiceRegistry: done, _ = await asyncio.wait([fut], timeout=SERVICE_CALL_LIMIT) success = bool(done) - unsub() + if not success: + unsub() return success async def _event_to_service_call(self, event: Event) -> None: @@ -1077,6 +1081,7 @@ class ServiceRegistry: service_data = event.data.get(ATTR_SERVICE_DATA) or {} domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore service = event.data.get(ATTR_SERVICE).lower() # type: ignore + call_id = event.data.get(ATTR_SERVICE_CALL_ID) if not self.has_service(domain, service): if event.origin == EventOrigin.local: @@ -1088,12 +1093,17 @@ class ServiceRegistry: def fire_service_executed() -> None: """Fire service executed event.""" + if not call_id: + return + + data = {ATTR_SERVICE_CALL_ID: call_id} + if (service_handler.is_coroutinefunction or service_handler.is_callback): - self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, {}, + self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, data, EventOrigin.local, event.context) else: - self._hass.bus.fire(EVENT_SERVICE_EXECUTED, {}, + self._hass.bus.fire(EVENT_SERVICE_EXECUTED, data, EventOrigin.local, event.context) try: diff --git a/tests/test_core.py b/tests/test_core.py index ce066135709f1fc30bddb398fb1d1b421a752fe8..7e6d57136e45c61caf7a415bdd7fe6c89e5888a5 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -20,9 +20,10 @@ from homeassistant.util.unit_system import (METRIC_SYSTEM) from homeassistant.const import ( __version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM, ATTR_NOW, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, - EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED) + EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED, + EVENT_SERVICE_EXECUTED) -from tests.common import get_test_home_assistant +from tests.common import get_test_home_assistant, async_mock_service PST = pytz.timezone('America/Los_Angeles') @@ -969,3 +970,27 @@ def test_track_task_functions(loop): assert hass._track_task finally: yield from hass.async_stop() + + +async def test_service_executed_with_subservices(hass): + """Test we block correctly till all services done.""" + calls = async_mock_service(hass, 'test', 'inner') + + async def handle_outer(call): + """Handle outer service call.""" + calls.append(call) + call1 = hass.services.async_call('test', 'inner', blocking=True, + context=call.context) + call2 = hass.services.async_call('test', 'inner', blocking=True, + context=call.context) + await asyncio.wait([call1, call2]) + calls.append(call) + + hass.services.async_register('test', 'outer', handle_outer) + + await hass.services.async_call('test', 'outer', blocking=True) + + assert len(calls) == 4 + assert [call.service for call in calls] == [ + 'outer', 'inner', 'inner', 'outer'] + assert len(hass.bus.async_listeners().get(EVENT_SERVICE_EXECUTED, [])) == 0