diff --git a/tests/common.py b/tests/common.py index d36df5091426443cdeddbcca111faddfbbc23341..684b9eb0433ce98437b525fe88ea9cfbe8b9597f 100644 --- a/tests/common.py +++ b/tests/common.py @@ -13,7 +13,12 @@ from collections.abc import ( Mapping, Sequence, ) -from contextlib import asynccontextmanager, contextmanager, suppress +from contextlib import ( + AbstractAsyncContextManager, + asynccontextmanager, + contextmanager, + suppress, +) from datetime import UTC, datetime, timedelta from enum import Enum import functools as ft @@ -177,24 +182,36 @@ def get_test_config_dir(*add_path): @contextmanager def get_test_home_assistant() -> Generator[HomeAssistant]: """Return a Home Assistant object pointing at test config directory.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - context_manager = async_test_home_assistant(loop) - hass = loop.run_until_complete(context_manager.__aenter__()) - + hass_created_event = threading.Event() loop_stop_event = threading.Event() + context_manager: AbstractAsyncContextManager = None + hass: HomeAssistant = None + loop: asyncio.AbstractEventLoop = None + orig_stop: Callable = None + def run_loop() -> None: - """Run event loop.""" + """Create and run event loop.""" + nonlocal context_manager, hass, loop, orig_stop + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + context_manager = async_test_home_assistant(loop) + hass = loop.run_until_complete(context_manager.__aenter__()) + + orig_stop = hass.stop + hass._stopped = Mock(set=loop.stop) + hass.start = start_hass + hass.stop = stop_hass loop._thread_ident = threading.get_ident() + + hass_created_event.set() + hass.loop_thread_id = loop._thread_ident loop.run_forever() loop_stop_event.set() - orig_stop = hass.stop - hass._stopped = Mock(set=loop.stop) - def start_hass(*mocks: Any) -> None: """Start hass.""" asyncio.run_coroutine_threadsafe(hass.async_start(), loop).result() @@ -204,11 +221,10 @@ def get_test_home_assistant() -> Generator[HomeAssistant]: orig_stop() loop_stop_event.wait() - hass.start = start_hass - hass.stop = stop_hass - threading.Thread(name="LoopThread", target=run_loop, daemon=False).start() + hass_created_event.wait() + try: yield hass finally: