From 61ff1b786bf6d6341b7c40895fa478058532223f Mon Sep 17 00:00:00 2001
From: Erik Montnemery <erik@montnemery.com>
Date: Mon, 22 Aug 2022 15:58:01 +0200
Subject: [PATCH] Add a context variable holding a HomeAssistant reference
 (#76303)

* Add a context variable holding a HomeAssistant reference

* Move variable setup and update test

* Refactor

* Revert "Refactor"

This reverts commit 346d005ee67b9e27e05363d04a7f48eaf416a16b.

* Set context variable when creating HomeAssistant object

* Update docstring

* Update docstring

Co-authored-by: jbouwh <jan@jbsoft.nl>
---
 homeassistant/core.py   | 21 +++++++++++++++++++++
 tests/test_bootstrap.py |  2 ++
 2 files changed, 23 insertions(+)

diff --git a/homeassistant/core.py b/homeassistant/core.py
index fcd41ddc856..01c75fb707e 100644
--- a/homeassistant/core.py
+++ b/homeassistant/core.py
@@ -15,6 +15,7 @@ from collections.abc import (
     Iterable,
     Mapping,
 )
+from contextvars import ContextVar
 import datetime
 import enum
 import functools
@@ -138,6 +139,8 @@ MAX_EXPECTED_ENTITY_IDS = 16384
 
 _LOGGER = logging.getLogger(__name__)
 
+_cv_hass: ContextVar[HomeAssistant] = ContextVar("current_entry")
+
 
 @functools.lru_cache(MAX_EXPECTED_ENTITY_IDS)
 def split_entity_id(entity_id: str) -> tuple[str, str]:
@@ -175,6 +178,18 @@ def is_callback(func: Callable[..., Any]) -> bool:
     return getattr(func, "_hass_callback", False) is True
 
 
+@callback
+def async_get_hass() -> HomeAssistant:
+    """Return the HomeAssistant instance.
+
+    Raises LookupError if no HomeAssistant instance is available.
+
+    This should be used where it's very cumbersome or downright impossible to pass
+    hass to the code which needs it.
+    """
+    return _cv_hass.get()
+
+
 @enum.unique
 class HassJobType(enum.Enum):
     """Represent a job type."""
@@ -242,6 +257,12 @@ class HomeAssistant:
     http: HomeAssistantHTTP = None  # type: ignore[assignment]
     config_entries: ConfigEntries = None  # type: ignore[assignment]
 
+    def __new__(cls) -> HomeAssistant:
+        """Set the _cv_hass context variable."""
+        hass = super().__new__(cls)
+        _cv_hass.set(hass)
+        return hass
+
     def __init__(self) -> None:
         """Initialize new Home Assistant object."""
         self.loop = asyncio.get_running_loop()
diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py
index 06f800af7f3..56c15f49337 100644
--- a/tests/test_bootstrap.py
+++ b/tests/test_bootstrap.py
@@ -501,6 +501,8 @@ async def test_setup_hass(
     assert len(mock_ensure_config_exists.mock_calls) == 1
     assert len(mock_process_ha_config_upgrade.mock_calls) == 1
 
+    assert hass == core.async_get_hass()
+
 
 async def test_setup_hass_takes_longer_than_log_slow_startup(
     mock_enable_logging,
-- 
GitLab