From bf1b408038a0b0bf3201d3752e17b4db6fbabe69 Mon Sep 17 00:00:00 2001 From: Phil Bruckner <pnbruckner@gmail.com> Date: Sat, 4 Apr 2020 17:36:33 -0500 Subject: [PATCH] Handle cancellation in ServiceRegistry.async_call (#33644) --- homeassistant/core.py | 63 +++++++++++++++++++++++++++++++------------ tests/test_core.py | 36 +++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 17 deletions(-) diff --git a/homeassistant/core.py b/homeassistant/core.py index 54f1c1cd366..7e85e7616a8 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -28,6 +28,7 @@ from typing import ( Optional, Set, TypeVar, + Union, ) import uuid @@ -1224,29 +1225,57 @@ class ServiceRegistry: context=context, ) + coro = self._execute_service(handler, service_call) if not blocking: - self._hass.async_create_task(self._safe_execute(handler, service_call)) + self._run_service_in_background(coro, service_call) return None + task = self._hass.async_create_task(coro) try: - async with timeout(limit): - await asyncio.shield(self._execute_service(handler, service_call)) + await asyncio.wait({task}, timeout=limit) + except asyncio.CancelledError: + # Task calling us was cancelled, so cancel service call task, and wait for + # it to be cancelled, within reason, before leaving. + _LOGGER.debug("Service call was cancelled: %s", service_call) + task.cancel() + await asyncio.wait({task}, timeout=SERVICE_CALL_LIMIT) + raise + + if task.cancelled(): + # Service call task was cancelled some other way, such as during shutdown. + _LOGGER.debug("Service was cancelled: %s", service_call) + raise asyncio.CancelledError + if task.done(): + # Propagate any exceptions that might have happened during service call. + task.result() + # Service call completed successfully! return True - except asyncio.TimeoutError: - return False + # Service call task did not complete before timeout expired. + # Let it keep running in background. + self._run_service_in_background(task, service_call) + _LOGGER.debug("Service did not complete before timeout: %s", service_call) + return False - async def _safe_execute(self, handler: Service, service_call: ServiceCall) -> None: - """Execute a service and catch exceptions.""" - try: - await self._execute_service(handler, service_call) - except Unauthorized: - _LOGGER.warning( - "Unauthorized service called %s/%s", - service_call.domain, - service_call.service, - ) - except Exception: # pylint: disable=broad-except - _LOGGER.exception("Error executing service %s", service_call) + def _run_service_in_background( + self, coro_or_task: Union[Coroutine, asyncio.Task], service_call: ServiceCall + ) -> None: + """Run service call in background, catching and logging any exceptions.""" + + async def catch_exceptions() -> None: + try: + await coro_or_task + except Unauthorized: + _LOGGER.warning( + "Unauthorized service called %s/%s", + service_call.domain, + service_call.service, + ) + except asyncio.CancelledError: + _LOGGER.debug("Service was cancelled: %s", service_call) + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Error executing service: %s", service_call) + + self._hass.async_create_task(catch_exceptions()) async def _execute_service( self, handler: Service, service_call: ServiceCall diff --git a/tests/test_core.py b/tests/test_core.py index 5e6bb090821..deeb808396a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1214,6 +1214,42 @@ async def test_async_functions_with_callback(hass): assert len(runs) == 3 +@pytest.mark.parametrize("cancel_call", [True, False]) +async def test_cancel_service_task(hass, cancel_call): + """Test cancellation.""" + service_called = asyncio.Event() + service_cancelled = False + + async def service_handler(call): + nonlocal service_cancelled + service_called.set() + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + service_cancelled = True + raise + + hass.services.async_register("test_domain", "test_service", service_handler) + call_task = hass.async_create_task( + hass.services.async_call("test_domain", "test_service", blocking=True) + ) + + tasks_1 = asyncio.all_tasks() + await asyncio.wait_for(service_called.wait(), timeout=1) + tasks_2 = asyncio.all_tasks() - tasks_1 + assert len(tasks_2) == 1 + service_task = tasks_2.pop() + + if cancel_call: + call_task.cancel() + else: + service_task.cancel() + with pytest.raises(asyncio.CancelledError): + await call_task + + assert service_cancelled + + def test_valid_entity_id(): """Test valid entity ID.""" for invalid in [ -- GitLab