From 76182c246d8f2c302d1b85f4443aa3d00b9bb9cc Mon Sep 17 00:00:00 2001
From: "J. Nick Koston" <nick@koston.org>
Date: Mon, 26 Aug 2024 08:37:36 -1000
Subject: [PATCH] Auto configure outbound websocket for sleepy shelly RPC
 devices (#124545)

---
 homeassistant/components/shelly/const.py      |   3 +
 .../components/shelly/coordinator.py          |  82 ++++++--
 homeassistant/components/shelly/utils.py      |  17 +-
 tests/components/shelly/conftest.py           |   1 +
 tests/components/shelly/test_config_flow.py   | 178 +++++++++++++++++-
 5 files changed, 259 insertions(+), 22 deletions(-)

diff --git a/homeassistant/components/shelly/const.py b/homeassistant/components/shelly/const.py
index 1759f4bdd18..fe4108a1f52 100644
--- a/homeassistant/components/shelly/const.py
+++ b/homeassistant/components/shelly/const.py
@@ -254,3 +254,6 @@ VIRTUAL_NUMBER_MODE_MAP = {
     "field": NumberMode.BOX,
     "slider": NumberMode.SLIDER,
 }
+
+
+API_WS_URL = "/api/shelly/ws"
diff --git a/homeassistant/components/shelly/coordinator.py b/homeassistant/components/shelly/coordinator.py
index 6286e515727..03dcdedbb6f 100644
--- a/homeassistant/components/shelly/coordinator.py
+++ b/homeassistant/components/shelly/coordinator.py
@@ -64,6 +64,7 @@ from .utils import (
     get_host,
     get_http_port,
     get_rpc_device_wakeup_period,
+    get_rpc_ws_url,
     update_device_fw_info,
 )
 
@@ -101,6 +102,9 @@ class ShellyCoordinatorBase[_DeviceT: BlockDevice | RpcDevice](
         self._pending_platforms: list[Platform] | None = None
         device_name = device.name if device.initialized else entry.title
         interval_td = timedelta(seconds=update_interval)
+        # Whether the device has come online at least once. For a sleeping RPC
+        # device, this means it has connected to the WS server at least once.
+        self._came_online_once = False
         super().__init__(hass, LOGGER, name=device_name, update_interval=interval_td)
 
         self._debounced_reload: Debouncer[Coroutine[Any, Any, None]] = Debouncer(
@@ -184,7 +188,7 @@ class ShellyCoordinatorBase[_DeviceT: BlockDevice | RpcDevice](
         if not self._pending_platforms:
             return True
 
-        LOGGER.debug("Device %s is online, resuming setup", self.entry.title)
+        LOGGER.debug("Device %s is online, resuming setup", self.name)
         platforms = self._pending_platforms
         self._pending_platforms = None
 
@@ -372,6 +376,7 @@ class ShellyBlockCoordinator(ShellyCoordinatorBase[BlockDevice]):
         """Handle device update."""
         LOGGER.debug("Shelly %s handle update, type: %s", self.name, update_type)
         if update_type is BlockUpdateType.ONLINE:
+            self._came_online_once = True
             self.entry.async_create_background_task(
                 self.hass,
                 self._async_device_connect_task(),
@@ -472,9 +477,24 @@ class ShellyRpcCoordinator(ShellyCoordinatorBase[RpcDevice]):
         self._event_listeners: list[Callable[[dict[str, Any]], None]] = []
         self._ota_event_listeners: list[Callable[[dict[str, Any]], None]] = []
         self._input_event_listeners: list[Callable[[dict[str, Any]], None]] = []
-
+        self._connect_task: asyncio.Task | None = None
         entry.async_on_unload(entry.add_update_listener(self._async_update_listener))
 
+    async def async_device_online(self) -> None:
+        """Handle device going online."""
+        if not self.sleep_period:
+            await self.async_request_refresh()
+        elif not self._came_online_once or not self.device.initialized:
+            LOGGER.debug(
+                "Sleepy device %s is online, trying to poll and configure", self.name
+            )
+            # Zeroconf told us the device is online; try to poll
+            # the device and, if possible, set up the outbound
+            # websocket so the device will send us updates
+            # instead of relying on polling it fast enough before
+            # it goes to sleep again.
+            self._async_handle_rpc_device_online()
+
     def update_sleep_period(self) -> bool:
         """Check device sleep period & update if changed."""
         if (
@@ -598,15 +618,15 @@ class ShellyRpcCoordinator(ShellyCoordinatorBase[RpcDevice]):
 
     async def _async_disconnected(self, reconnect: bool) -> None:
         """Handle device disconnected."""
-        # Sleeping devices send data and disconnect
-        # There are no disconnect events for sleeping devices
-        if self.sleep_period:
-            return
-
         async with self._connection_lock:
             if not self.connected:  # Already disconnected
                 return
             self.connected = False
+            # Sleeping devices send data and disconnect.
+            # There are no disconnect events for sleeping devices,
+            # but we do need to make sure self.connected is False.
+            if self.sleep_period:
+                return
             self._async_run_disconnected_events()
         # Try to reconnect right away if triggered by disconnect event
         if reconnect:
@@ -645,6 +665,21 @@ class ShellyRpcCoordinator(ShellyCoordinatorBase[RpcDevice]):
         """
         if not self.sleep_period:
             await self._async_connect_ble_scanner()
+        else:
+            await self._async_setup_outbound_websocket()
+
+    async def _async_setup_outbound_websocket(self) -> None:
+        """Set up outbound websocket if it is not enabled."""
+        config = self.device.config
+        if (
+            (ws_config := config.get("ws"))
+            and (not ws_config["server"] or not ws_config["enable"])
+            and (ws_url := get_rpc_ws_url(self.hass))
+        ):
+            LOGGER.debug(
+                "Setting up outbound websocket for device %s - %s", self.name, ws_url
+            )
+            await self.device.update_outbound_websocket(ws_url)
 
     async def _async_connect_ble_scanner(self) -> None:
         """Connect BLE scanner."""
@@ -662,6 +697,21 @@ class ShellyRpcCoordinator(ShellyCoordinatorBase[RpcDevice]):
             await async_connect_scanner(self.hass, self, ble_scanner_mode)
         )
 
+    @callback
+    def _async_handle_rpc_device_online(self) -> None:
+        """Start the connect task unless already connected or connecting."""
+        if self.device.connected or (
+            self._connect_task and not self._connect_task.done()
+        ):
+            LOGGER.debug("Device %s already connected/connecting", self.name)
+            return
+        self._connect_task = self.entry.async_create_background_task(
+            self.hass,
+            self._async_device_connect_task(),
+            "rpc device online",
+            eager_start=True,
+        )
+
     @callback
     def _async_handle_update(
         self, device_: RpcDevice, update_type: RpcUpdateType
@@ -669,15 +719,8 @@ class ShellyRpcCoordinator(ShellyCoordinatorBase[RpcDevice]):
         """Handle device update."""
         LOGGER.debug("Shelly %s handle update, type: %s", self.name, update_type)
         if update_type is RpcUpdateType.ONLINE:
-            if self.device.connected:
-                LOGGER.debug("Device %s already connected", self.name)
-                return
-            self.entry.async_create_background_task(
-                self.hass,
-                self._async_device_connect_task(),
-                "rpc device online",
-                eager_start=True,
-            )
+            self._came_online_once = True
+            self._async_handle_rpc_device_online()
         elif update_type is RpcUpdateType.INITIALIZED:
             self.entry.async_create_background_task(
                 self.hass, self._async_connected(), "rpc device init", eager_start=True
@@ -798,14 +841,13 @@ def get_rpc_coordinator_by_device_id(
 async def async_reconnect_soon(hass: HomeAssistant, entry: ShellyConfigEntry) -> None:
     """Try to reconnect soon."""
     if (
-        not entry.data.get(CONF_SLEEP_PERIOD)
-        and not hass.is_stopping
-        and entry.state == ConfigEntryState.LOADED
+        not hass.is_stopping
+        and entry.state is ConfigEntryState.LOADED
         and (coordinator := entry.runtime_data.rpc)
     ):
         entry.async_create_background_task(
             hass,
-            coordinator.async_request_refresh(),
+            coordinator.async_device_online(),
             "reconnect soon",
             eager_start=True,
         )
diff --git a/homeassistant/components/shelly/utils.py b/homeassistant/components/shelly/utils.py
index 339f6781171..d0a8a1230c5 100644
--- a/homeassistant/components/shelly/utils.py
+++ b/homeassistant/components/shelly/utils.py
@@ -23,6 +23,7 @@ from aioshelly.const import (
     RPC_GENERATIONS,
 )
 from aioshelly.rpc_device import RpcDevice, WsServer
+from yarl import URL
 
 from homeassistant.components import network
 from homeassistant.components.http import HomeAssistantView
@@ -36,9 +37,11 @@ from homeassistant.helpers import (
     singleton,
 )
 from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
+from homeassistant.helpers.network import NoURLAvailableError, get_url
 from homeassistant.util.dt import utcnow
 
 from .const import (
+    API_WS_URL,
     BASIC_INPUTS_EVENTS_TYPES,
     CONF_COAP_PORT,
     CONF_GEN,
@@ -254,7 +257,7 @@ class ShellyReceiver(HomeAssistantView):
     """Handle pushes from Shelly Gen2 devices."""
 
     requires_auth = False
-    url = "/api/shelly/ws"
+    url = API_WS_URL
     name = "api:shelly:ws"
 
     def __init__(self, ws_server: WsServer) -> None:
@@ -571,3 +574,15 @@ def async_remove_orphaned_virtual_entities(
 
     if orphaned_entities:
         async_remove_shelly_rpc_entities(hass, platform, mac, orphaned_entities)
+
+
+def get_rpc_ws_url(hass: HomeAssistant) -> str | None:
+    """Return the RPC websocket URL, or None if no URL is available."""
+    try:
+        raw_url = get_url(hass, prefer_external=False, allow_cloud=False)
+    except NoURLAvailableError:
+        LOGGER.debug("URL not available, skipping outbound websocket setup")
+        return None
+    url = URL(raw_url)
+    ws_url = url.with_scheme("wss" if url.scheme == "https" else "ws")
+    return str(ws_url.joinpath(API_WS_URL.removeprefix("/")))
diff --git a/tests/components/shelly/conftest.py b/tests/components/shelly/conftest.py
index a2629d21362..34e4ce1379e 100644
--- a/tests/components/shelly/conftest.py
+++ b/tests/components/shelly/conftest.py
@@ -186,6 +186,7 @@ MOCK_CONFIG = {
         "device": {"name": "Test name"},
     },
     "wifi": {"sta": {"enable": True}, "sta1": {"enable": False}},
+    "ws": {"enable": False, "server": None},
 }
 
 MOCK_SHELLY_COAP = {
diff --git a/tests/components/shelly/test_config_flow.py b/tests/components/shelly/test_config_flow.py
index 0c574a33e0c..c0c089f469a 100644
--- a/tests/components/shelly/test_config_flow.py
+++ b/tests/components/shelly/test_config_flow.py
@@ -4,7 +4,7 @@ from dataclasses import replace
 from datetime import timedelta
 from ipaddress import ip_address
 from typing import Any
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import AsyncMock, Mock, call, patch
 
 from aioshelly.const import DEFAULT_HTTP_PORT, MODEL_1, MODEL_PLUS_2PM
 from aioshelly.exceptions import (
@@ -1153,6 +1153,182 @@ async def test_zeroconf_sleeping_device_not_triggers_refresh(
     assert "device did not update" not in caplog.text
 
 
+async def test_zeroconf_sleeping_device_attempts_configure(
+    hass: HomeAssistant,
+    mock_rpc_device: Mock,
+    monkeypatch: pytest.MonkeyPatch,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    """Test zeroconf discovery configures a sleeping device outbound websocket."""
+    monkeypatch.setattr(mock_rpc_device, "connected", False)
+    monkeypatch.setattr(mock_rpc_device, "initialized", False)
+    monkeypatch.setitem(mock_rpc_device.status["sys"], "wakeup_period", 1000)
+    entry = MockConfigEntry(
+        domain="shelly",
+        unique_id="AABBCCDDEEFF",
+        data={"host": "1.1.1.1", "gen": 2, "sleep_period": 1000, "model": MODEL_1},
+    )
+    entry.add_to_hass(hass)
+    await hass.config_entries.async_setup(entry.entry_id)
+    await hass.async_block_till_done()
+    mock_rpc_device.mock_disconnected()
+    await hass.async_block_till_done()
+
+    mock_rpc_device.mock_online()
+    await hass.async_block_till_done(wait_background_tasks=True)
+
+    assert "online, resuming setup" in caplog.text
+    assert len(mock_rpc_device.initialize.mock_calls) == 1
+
+    with patch(
+        "homeassistant.components.shelly.config_flow.get_info",
+        return_value={"mac": "AABBCCDDEEFF", "type": MODEL_1, "auth": False},
+    ):
+        result = await hass.config_entries.flow.async_init(
+            DOMAIN,
+            data=DISCOVERY_INFO,
+            context={"source": config_entries.SOURCE_ZEROCONF},
+        )
+        assert result["type"] is FlowResultType.ABORT
+        assert result["reason"] == "already_configured"
+
+    assert mock_rpc_device.update_outbound_websocket.mock_calls == []
+
+    monkeypatch.setattr(mock_rpc_device, "connected", True)
+    monkeypatch.setattr(mock_rpc_device, "initialized", True)
+    mock_rpc_device.mock_initialized()
+    async_fire_time_changed(
+        hass, dt_util.utcnow() + timedelta(seconds=ENTRY_RELOAD_COOLDOWN)
+    )
+    await hass.async_block_till_done()
+    assert "device did not update" not in caplog.text
+
+    monkeypatch.setattr(mock_rpc_device, "connected", False)
+    mock_rpc_device.mock_disconnected()
+    assert mock_rpc_device.update_outbound_websocket.mock_calls == [
+        call("ws://10.10.10.10:8123/api/shelly/ws")
+    ]
+
+
+async def test_zeroconf_sleeping_device_attempts_configure_ws_disabled(
+    hass: HomeAssistant,
+    mock_rpc_device: Mock,
+    monkeypatch: pytest.MonkeyPatch,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    """Test zeroconf discovery configures a sleeping device's outbound websocket when it is disabled."""
+    monkeypatch.setattr(mock_rpc_device, "connected", False)
+    monkeypatch.setattr(mock_rpc_device, "initialized", False)
+    monkeypatch.setitem(mock_rpc_device.status["sys"], "wakeup_period", 1000)
+    monkeypatch.setitem(
+        mock_rpc_device.config, "ws", {"enable": False, "server": "ws://oldha"}
+    )
+    entry = MockConfigEntry(
+        domain="shelly",
+        unique_id="AABBCCDDEEFF",
+        data={"host": "1.1.1.1", "gen": 2, "sleep_period": 1000, "model": MODEL_1},
+    )
+    entry.add_to_hass(hass)
+    await hass.config_entries.async_setup(entry.entry_id)
+    await hass.async_block_till_done()
+    mock_rpc_device.mock_disconnected()
+    await hass.async_block_till_done()
+
+    mock_rpc_device.mock_online()
+    await hass.async_block_till_done(wait_background_tasks=True)
+
+    assert "online, resuming setup" in caplog.text
+    assert len(mock_rpc_device.initialize.mock_calls) == 1
+
+    with patch(
+        "homeassistant.components.shelly.config_flow.get_info",
+        return_value={"mac": "AABBCCDDEEFF", "type": MODEL_1, "auth": False},
+    ):
+        result = await hass.config_entries.flow.async_init(
+            DOMAIN,
+            data=DISCOVERY_INFO,
+            context={"source": config_entries.SOURCE_ZEROCONF},
+        )
+        assert result["type"] is FlowResultType.ABORT
+        assert result["reason"] == "already_configured"
+
+    assert mock_rpc_device.update_outbound_websocket.mock_calls == []
+
+    monkeypatch.setattr(mock_rpc_device, "connected", True)
+    monkeypatch.setattr(mock_rpc_device, "initialized", True)
+    mock_rpc_device.mock_initialized()
+    async_fire_time_changed(
+        hass, dt_util.utcnow() + timedelta(seconds=ENTRY_RELOAD_COOLDOWN)
+    )
+    await hass.async_block_till_done()
+    assert "device did not update" not in caplog.text
+
+    monkeypatch.setattr(mock_rpc_device, "connected", False)
+    mock_rpc_device.mock_disconnected()
+    assert mock_rpc_device.update_outbound_websocket.mock_calls == [
+        call("ws://10.10.10.10:8123/api/shelly/ws")
+    ]
+
+
+async def test_zeroconf_sleeping_device_attempts_configure_no_url_available(
+    hass: HomeAssistant,
+    mock_rpc_device: Mock,
+    monkeypatch: pytest.MonkeyPatch,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    """Test zeroconf discovery for sleeping device with no hass url."""
+    hass.config.internal_url = None
+    hass.config.external_url = None
+    hass.config.api = None
+    monkeypatch.setattr(mock_rpc_device, "connected", False)
+    monkeypatch.setattr(mock_rpc_device, "initialized", False)
+    monkeypatch.setitem(mock_rpc_device.status["sys"], "wakeup_period", 1000)
+    entry = MockConfigEntry(
+        domain="shelly",
+        unique_id="AABBCCDDEEFF",
+        data={"host": "1.1.1.1", "gen": 2, "sleep_period": 1000, "model": MODEL_1},
+    )
+    entry.add_to_hass(hass)
+    await hass.config_entries.async_setup(entry.entry_id)
+    await hass.async_block_till_done()
+    mock_rpc_device.mock_disconnected()
+    await hass.async_block_till_done()
+
+    mock_rpc_device.mock_online()
+    await hass.async_block_till_done(wait_background_tasks=True)
+
+    assert "online, resuming setup" in caplog.text
+    assert len(mock_rpc_device.initialize.mock_calls) == 1
+
+    with patch(
+        "homeassistant.components.shelly.config_flow.get_info",
+        return_value={"mac": "AABBCCDDEEFF", "type": MODEL_1, "auth": False},
+    ):
+        result = await hass.config_entries.flow.async_init(
+            DOMAIN,
+            data=DISCOVERY_INFO,
+            context={"source": config_entries.SOURCE_ZEROCONF},
+        )
+        assert result["type"] is FlowResultType.ABORT
+        assert result["reason"] == "already_configured"
+
+    assert mock_rpc_device.update_outbound_websocket.mock_calls == []
+
+    monkeypatch.setattr(mock_rpc_device, "connected", True)
+    monkeypatch.setattr(mock_rpc_device, "initialized", True)
+    mock_rpc_device.mock_initialized()
+    async_fire_time_changed(
+        hass, dt_util.utcnow() + timedelta(seconds=ENTRY_RELOAD_COOLDOWN)
+    )
+    await hass.async_block_till_done()
+    assert "device did not update" not in caplog.text
+
+    monkeypatch.setattr(mock_rpc_device, "connected", False)
+    mock_rpc_device.mock_disconnected()
+    # No url available so no attempt to configure the device
+    assert mock_rpc_device.update_outbound_websocket.mock_calls == []
+
+
 async def test_sleeping_device_gen2_with_new_firmware(
     hass: HomeAssistant, mock_rpc_device: Mock, monkeypatch: pytest.MonkeyPatch
 ) -> None:
-- 
GitLab