From bf1b408038a0b0bf3201d3752e17b4db6fbabe69 Mon Sep 17 00:00:00 2001
From: Phil Bruckner <pnbruckner@gmail.com>
Date: Sat, 4 Apr 2020 17:36:33 -0500
Subject: [PATCH] Handle cancellation in ServiceRegistry.async_call (#33644)

---
 homeassistant/core.py | 63 +++++++++++++++++++++++++++++++------------
 tests/test_core.py    | 36 +++++++++++++++++++++++++
 2 files changed, 82 insertions(+), 17 deletions(-)

diff --git a/homeassistant/core.py b/homeassistant/core.py
index 54f1c1cd366..7e85e7616a8 100644
--- a/homeassistant/core.py
+++ b/homeassistant/core.py
@@ -28,6 +28,7 @@ from typing import (
     Optional,
     Set,
     TypeVar,
+    Union,
 )
 import uuid
 
@@ -1224,29 +1225,57 @@ class ServiceRegistry:
             context=context,
         )
 
+        coro = self._execute_service(handler, service_call)
         if not blocking:
-            self._hass.async_create_task(self._safe_execute(handler, service_call))
+            self._run_service_in_background(coro, service_call)
             return None
 
+        task = self._hass.async_create_task(coro)
         try:
-            async with timeout(limit):
-                await asyncio.shield(self._execute_service(handler, service_call))
+            await asyncio.wait({task}, timeout=limit)
+        except asyncio.CancelledError:
+            # Task calling us was cancelled, so cancel service call task, and wait for
+            # it to be cancelled, within reason, before leaving.
+            _LOGGER.debug("Service call was cancelled: %s", service_call)
+            task.cancel()
+            await asyncio.wait({task}, timeout=SERVICE_CALL_LIMIT)
+            raise
+
+        if task.cancelled():
+            # Service call task was cancelled some other way, such as during shutdown.
+            _LOGGER.debug("Service was cancelled: %s", service_call)
+            raise asyncio.CancelledError
+        if task.done():
+            # Propagate any exceptions that might have happened during service call.
+            task.result()
+            # Service call completed successfully!
             return True
-        except asyncio.TimeoutError:
-            return False
+        # Service call task did not complete before timeout expired.
+        # Let it keep running in background.
+        self._run_service_in_background(task, service_call)
+        _LOGGER.debug("Service did not complete before timeout: %s", service_call)
+        return False
 
-    async def _safe_execute(self, handler: Service, service_call: ServiceCall) -> None:
-        """Execute a service and catch exceptions."""
-        try:
-            await self._execute_service(handler, service_call)
-        except Unauthorized:
-            _LOGGER.warning(
-                "Unauthorized service called %s/%s",
-                service_call.domain,
-                service_call.service,
-            )
-        except Exception:  # pylint: disable=broad-except
-            _LOGGER.exception("Error executing service %s", service_call)
+    def _run_service_in_background(
+        self, coro_or_task: Union[Coroutine, asyncio.Task], service_call: ServiceCall
+    ) -> None:
+        """Run service call in background, catching and logging any exceptions."""
+
+        async def catch_exceptions() -> None:
+            try:
+                await coro_or_task
+            except Unauthorized:
+                _LOGGER.warning(
+                    "Unauthorized service called %s/%s",
+                    service_call.domain,
+                    service_call.service,
+                )
+            except asyncio.CancelledError:
+                _LOGGER.debug("Service was cancelled: %s", service_call)
+            except Exception:  # pylint: disable=broad-except
+                _LOGGER.exception("Error executing service: %s", service_call)
+
+        self._hass.async_create_task(catch_exceptions())
 
     async def _execute_service(
         self, handler: Service, service_call: ServiceCall
diff --git a/tests/test_core.py b/tests/test_core.py
index 5e6bb090821..deeb808396a 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1214,6 +1214,42 @@ async def test_async_functions_with_callback(hass):
     assert len(runs) == 3
 
 
+@pytest.mark.parametrize("cancel_call", [True, False])
+async def test_cancel_service_task(hass, cancel_call):
+    """Test cancellation."""
+    service_called = asyncio.Event()
+    service_cancelled = False
+
+    async def service_handler(call):
+        nonlocal service_cancelled
+        service_called.set()
+        try:
+            await asyncio.sleep(10)
+        except asyncio.CancelledError:
+            service_cancelled = True
+            raise
+
+    hass.services.async_register("test_domain", "test_service", service_handler)
+    call_task = hass.async_create_task(
+        hass.services.async_call("test_domain", "test_service", blocking=True)
+    )
+
+    tasks_1 = asyncio.all_tasks()
+    await asyncio.wait_for(service_called.wait(), timeout=1)
+    tasks_2 = asyncio.all_tasks() - tasks_1
+    assert len(tasks_2) == 1
+    service_task = tasks_2.pop()
+
+    if cancel_call:
+        call_task.cancel()
+    else:
+        service_task.cancel()
+    with pytest.raises(asyncio.CancelledError):
+        await call_task
+
+    assert service_cancelled
+
+
 def test_valid_entity_id():
     """Test valid entity ID."""
     for invalid in [
-- 
GitLab