From 18ee9f4725b6058a2b5d736c61554d5bd0be80ac Mon Sep 17 00:00:00 2001
From: Jan Bouwhuis <jbouwh@users.noreply.github.com>
Date: Fri, 7 Jul 2023 20:52:38 +0200
Subject: [PATCH] Refactor async_get_hass to rely on threading.local instead of
 a ContextVar (#96005)

* Test for async_get_hass

* Add Fix
---
 homeassistant/core.py                      |  22 ++-
 homeassistant/helpers/config_validation.py |   8 +-
 tests/conftest.py                          |  12 +-
 tests/helpers/test_config_validation.py    |   7 +-
 tests/test_core.py                         | 181 +++++++++++++++++++++
 5 files changed, 205 insertions(+), 25 deletions(-)

diff --git a/homeassistant/core.py b/homeassistant/core.py
index dbc8769bb6f..82ea7228157 100644
--- a/homeassistant/core.py
+++ b/homeassistant/core.py
@@ -16,7 +16,6 @@ from collections.abc import (
 )
 import concurrent.futures
 from contextlib import suppress
-from contextvars import ContextVar
 import datetime
 import enum
 import functools
@@ -155,8 +154,6 @@ MAX_EXPECTED_ENTITY_IDS = 16384
 
 _LOGGER = logging.getLogger(__name__)
 
-_cv_hass: ContextVar[HomeAssistant] = ContextVar("hass")
-
 
 @functools.lru_cache(MAX_EXPECTED_ENTITY_IDS)
 def split_entity_id(entity_id: str) -> tuple[str, str]:
@@ -199,16 +196,27 @@ def is_callback(func: Callable[..., Any]) -> bool:
     return getattr(func, "_hass_callback", False) is True
 
 
+class _Hass(threading.local):
+    """Container which makes a HomeAssistant instance available to the event loop."""
+
+    hass: HomeAssistant | None = None
+
+
+_hass = _Hass()
+
+
 @callback
 def async_get_hass() -> HomeAssistant:
     """Return the HomeAssistant instance.
 
-    Raises LookupError if no HomeAssistant instance is available.
+    Raises HomeAssistantError when called from the wrong thread.
 
     This should be used where it's very cumbersome or downright impossible to pass
     hass to the code which needs it.
     """
-    return _cv_hass.get()
+    if not _hass.hass:
+        raise HomeAssistantError("async_get_hass called from the wrong thread")
+    return _hass.hass
 
 
 @enum.unique
@@ -292,9 +300,9 @@ class HomeAssistant:
     config_entries: ConfigEntries = None  # type: ignore[assignment]
 
     def __new__(cls) -> HomeAssistant:
-        """Set the _cv_hass context variable."""
+        """Set the _hass thread local data."""
         hass = super().__new__(cls)
-        _cv_hass.set(hass)
+        _hass.hass = hass
         return hass
 
     def __init__(self) -> None:
diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py
index cea8a866f5c..e8f1e58615c 100644
--- a/homeassistant/helpers/config_validation.py
+++ b/homeassistant/helpers/config_validation.py
@@ -93,7 +93,7 @@ from homeassistant.core import (
     split_entity_id,
     valid_entity_id,
 )
-from homeassistant.exceptions import TemplateError
+from homeassistant.exceptions import HomeAssistantError, TemplateError
 from homeassistant.generated import currencies
 from homeassistant.generated.countries import COUNTRIES
 from homeassistant.generated.languages import LANGUAGES
@@ -609,7 +609,7 @@ def template(value: Any | None) -> template_helper.Template:
         raise vol.Invalid("template value should be a string")
 
     hass: HomeAssistant | None = None
-    with contextlib.suppress(LookupError):
+    with contextlib.suppress(HomeAssistantError):
         hass = async_get_hass()
 
     template_value = template_helper.Template(str(value), hass)
@@ -631,7 +631,7 @@ def dynamic_template(value: Any | None) -> template_helper.Template:
         raise vol.Invalid("template value does not contain a dynamic template")
 
     hass: HomeAssistant | None = None
-    with contextlib.suppress(LookupError):
+    with contextlib.suppress(HomeAssistantError):
         hass = async_get_hass()
 
     template_value = template_helper.Template(str(value), hass)
@@ -1098,7 +1098,7 @@ def _no_yaml_config_schema(
         # pylint: disable-next=import-outside-toplevel
         from .issue_registry import IssueSeverity, async_create_issue
 
-        with contextlib.suppress(LookupError):
+        with contextlib.suppress(HomeAssistantError):
             hass = async_get_hass()
             async_create_issue(
                 hass,
diff --git a/tests/conftest.py b/tests/conftest.py
index 56014d7a556..922e42c7a7e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -490,17 +490,7 @@ def hass_fixture_setup() -> list[bool]:
 
 
 @pytest.fixture
-def hass(_hass: HomeAssistant) -> HomeAssistant:
-    """Fixture to provide a test instance of Home Assistant."""
-    # This wraps the async _hass fixture inside a sync fixture, to ensure
-    # the `hass` context variable is set in the execution context in which
-    # the test itself is executed
-    ha._cv_hass.set(_hass)
-    return _hass
-
-
-@pytest.fixture
-async def _hass(
+async def hass(
     hass_fixture_setup: list[bool],
     event_loop: asyncio.AbstractEventLoop,
     load_registries: bool,
diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py
index 458774b748c..5ea6df42349 100644
--- a/tests/helpers/test_config_validation.py
+++ b/tests/helpers/test_config_validation.py
@@ -12,6 +12,7 @@ import voluptuous as vol
 
 import homeassistant
 from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
+from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import (
     config_validation as cv,
     issue_registry as ir,
@@ -383,7 +384,7 @@ def test_service() -> None:
     schema("homeassistant.turn_on")
 
 
-def test_service_schema() -> None:
+def test_service_schema(hass: HomeAssistant) -> None:
     """Test service_schema validation."""
     options = (
         {},
@@ -1550,10 +1551,10 @@ def test_config_entry_only_schema_cant_find_module() -> None:
 def test_config_entry_only_schema_no_hass(
     hass: HomeAssistant, caplog: pytest.LogCaptureFixture
 ) -> None:
-    """Test if the the hass context var is not set in our context."""
+    """Test if the the hass context is not set in our context."""
     with patch(
         "homeassistant.helpers.config_validation.async_get_hass",
-        side_effect=LookupError,
+        side_effect=HomeAssistantError,
     ):
         cv.config_entry_only_config_schema("test_domain")(
             {"test_domain": {"foo": "bar"}}
diff --git a/tests/test_core.py b/tests/test_core.py
index 8b63eab7b42..7e0766c8ac5 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -9,10 +9,12 @@ import gc
 import logging
 import os
 from tempfile import TemporaryDirectory
+import threading
 import time
 from typing import Any
 from unittest.mock import MagicMock, Mock, PropertyMock, patch
 
+import async_timeout
 import pytest
 import voluptuous as vol
 
@@ -40,6 +42,7 @@ from homeassistant.core import (
     ServiceResponse,
     State,
     SupportsResponse,
+    callback,
 )
 from homeassistant.exceptions import (
     HomeAssistantError,
@@ -202,6 +205,184 @@ def test_async_run_hass_job_delegates_non_async() -> None:
     assert len(hass.async_add_hass_job.mock_calls) == 1
 
 
+async def test_async_get_hass_can_be_called(hass: HomeAssistant) -> None:
+    """Test calling async_get_hass via different paths.
+
+    The test asserts async_get_hass can be called from:
+    - Coroutines and callbacks
+    - Callbacks scheduled from callbacks, coroutines and threads
+    - Coroutines scheduled from callbacks, coroutines and threads
+
+    The test also asserts async_get_hass can not be called from threads
+    other than the event loop.
+    """
+    task_finished = asyncio.Event()
+
+    def can_call_async_get_hass() -> bool:
+        """Test if it's possible to call async_get_hass."""
+        try:
+            if ha.async_get_hass() is hass:
+                return True
+            raise Exception
+        except HomeAssistantError:
+            return False
+
+        raise Exception
+
+    # Test scheduling a coroutine which calls async_get_hass via hass.async_create_task
+    async def _async_create_task() -> None:
+        task_finished.set()
+        assert can_call_async_get_hass()
+
+    hass.async_create_task(_async_create_task(), "create_task")
+    async with async_timeout.timeout(1):
+        await task_finished.wait()
+    task_finished.clear()
+
+    # Test scheduling a callback which calls async_get_hass via hass.async_add_job
+    @callback
+    def _add_job() -> None:
+        assert can_call_async_get_hass()
+        task_finished.set()
+
+    hass.async_add_job(_add_job)
+    async with async_timeout.timeout(1):
+        await task_finished.wait()
+    task_finished.clear()
+
+    # Test scheduling a callback which calls async_get_hass from a callback
+    @callback
+    def _schedule_callback_from_callback() -> None:
+        @callback
+        def _callback():
+            assert can_call_async_get_hass()
+            task_finished.set()
+
+        # Test the scheduled callback itself can call async_get_hass
+        assert can_call_async_get_hass()
+        hass.async_add_job(_callback)
+
+    _schedule_callback_from_callback()
+    async with async_timeout.timeout(1):
+        await task_finished.wait()
+    task_finished.clear()
+
+    # Test scheduling a coroutine which calls async_get_hass from a callback
+    @callback
+    def _schedule_coroutine_from_callback() -> None:
+        async def _coroutine():
+            assert can_call_async_get_hass()
+            task_finished.set()
+
+        # Test the scheduled callback itself can call async_get_hass
+        assert can_call_async_get_hass()
+        hass.async_add_job(_coroutine())
+
+    _schedule_coroutine_from_callback()
+    async with async_timeout.timeout(1):
+        await task_finished.wait()
+    task_finished.clear()
+
+    # Test scheduling a callback which calls async_get_hass from a coroutine
+    async def _schedule_callback_from_coroutine() -> None:
+        @callback
+        def _callback():
+            assert can_call_async_get_hass()
+            task_finished.set()
+
+        # Test the coroutine itself can call async_get_hass
+        assert can_call_async_get_hass()
+        hass.async_add_job(_callback)
+
+    await _schedule_callback_from_coroutine()
+    async with async_timeout.timeout(1):
+        await task_finished.wait()
+    task_finished.clear()
+
+    # Test scheduling a coroutine which calls async_get_hass from a coroutine
+    async def _schedule_callback_from_coroutine() -> None:
+        async def _coroutine():
+            assert can_call_async_get_hass()
+            task_finished.set()
+
+        # Test the coroutine itself can call async_get_hass
+        assert can_call_async_get_hass()
+        await hass.async_create_task(_coroutine())
+
+    await _schedule_callback_from_coroutine()
+    async with async_timeout.timeout(1):
+        await task_finished.wait()
+    task_finished.clear()
+
+    # Test scheduling a callback which calls async_get_hass from an executor
+    def _async_add_executor_job_add_job() -> None:
+        @callback
+        def _async_add_job():
+            assert can_call_async_get_hass()
+            task_finished.set()
+
+        # Test the executor itself can not call async_get_hass
+        assert not can_call_async_get_hass()
+        hass.add_job(_async_add_job)
+
+    await hass.async_add_executor_job(_async_add_executor_job_add_job)
+    async with async_timeout.timeout(1):
+        await task_finished.wait()
+    task_finished.clear()
+
+    # Test scheduling a coroutine which calls async_get_hass from an executor
+    def _async_add_executor_job_create_task() -> None:
+        async def _async_create_task() -> None:
+            assert can_call_async_get_hass()
+            task_finished.set()
+
+        # Test the executor itself can not call async_get_hass
+        assert not can_call_async_get_hass()
+        hass.create_task(_async_create_task())
+
+    await hass.async_add_executor_job(_async_add_executor_job_create_task)
+    async with async_timeout.timeout(1):
+        await task_finished.wait()
+    task_finished.clear()
+
+    # Test scheduling a callback which calls async_get_hass from a worker thread
+    class MyJobAddJob(threading.Thread):
+        @callback
+        def _my_threaded_job_add_job(self) -> None:
+            assert can_call_async_get_hass()
+            task_finished.set()
+
+        def run(self) -> None:
+            # Test the worker thread itself can not call async_get_hass
+            assert not can_call_async_get_hass()
+            hass.add_job(self._my_threaded_job_add_job)
+
+    my_job_add_job = MyJobAddJob()
+    my_job_add_job.start()
+    async with async_timeout.timeout(1):
+        await task_finished.wait()
+    task_finished.clear()
+    my_job_add_job.join()
+
+    # Test scheduling a coroutine which calls async_get_hass from a worker thread
+    class MyJobCreateTask(threading.Thread):
+        async def _my_threaded_job_create_task(self) -> None:
+            assert can_call_async_get_hass()
+            task_finished.set()
+
+        def run(self) -> None:
+            # Test the worker thread itself can not call async_get_hass
+            assert not can_call_async_get_hass()
+            hass.create_task(self._my_threaded_job_create_task())
+
+    my_job_create_task = MyJobCreateTask()
+    my_job_create_task.start()
+    async with async_timeout.timeout(1):
+        await task_finished.wait()
+    task_finished.clear()
+    my_job_create_task.join()
+
+
 async def test_stage_shutdown(hass: HomeAssistant) -> None:
     """Simulate a shutdown, test calling stuff."""
     test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP)
-- 
GitLab