From 5694e4190cddcab365cb0eb10201e5e1e193a9f0 Mon Sep 17 00:00:00 2001
From: Erik Montnemery <erik@montnemery.com>
Date: Mon, 9 Nov 2020 19:47:45 +0100
Subject: [PATCH] Extend WS API result when enabling an entity (#42667)

* Extend WS API result when enabling an entity

* Fix tests

* Fix tests

* Move updated registry entry to sub dict

* Address review comments

* Increase test coverage
---
 .../components/config/entity_registry.py      |  14 +-
 homeassistant/config_entries.py               |   6 +-
 .../components/config/test_entity_registry.py | 246 ++++++++++++++----
 tests/components/tasmota/test_sensor.py       |   6 +-
 tests/test_config_entries.py                  |   6 +-
 5 files changed, 212 insertions(+), 66 deletions(-)

diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py
index 6e4fa30e205..73327ecf23c 100644
--- a/homeassistant/components/config/entity_registry.py
+++ b/homeassistant/components/config/entity_registry.py
@@ -1,6 +1,7 @@
 """HTTP views to interact with the entity registry."""
 import voluptuous as vol
 
+from homeassistant import config_entries
 from homeassistant.components import websocket_api
 from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
 from homeassistant.components.websocket_api.decorators import (
@@ -113,10 +114,15 @@ async def websocket_update_entity(hass, connection, msg):
         connection.send_message(
             websocket_api.error_message(msg["id"], "invalid_info", str(err))
         )
-    else:
-        connection.send_message(
-            websocket_api.result_message(msg["id"], _entry_ext_dict(entry))
-        )
+        return
+    result = {"entity_entry": _entry_ext_dict(entry)}
+    if "disabled_by" in changes and changes["disabled_by"] is None:
+        config_entry = hass.config_entries.async_get_entry(entry.config_entry_id)
+        if config_entry and not config_entry.supports_unload:
+            result["require_restart"] = True
+        else:
+            result["reload_delay"] = config_entries.RELOAD_AFTER_UPDATE_DELAY
+    connection.send_result(msg["id"], result)
 
 
 @require_admin
diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py
index 7fa343146d2..af82db0ffbb 100644
--- a/homeassistant/config_entries.py
+++ b/homeassistant/config_entries.py
@@ -89,6 +89,8 @@ CONN_CLASS_LOCAL_POLL = "local_poll"
 CONN_CLASS_ASSUMED = "assumed"
 CONN_CLASS_UNKNOWN = "unknown"
 
+RELOAD_AFTER_UPDATE_DELAY = 30
+
 
 class ConfigError(HomeAssistantError):
     """Error while configuring an account."""
@@ -1112,8 +1114,6 @@ class SystemOptions:
 class EntityRegistryDisabledHandler:
     """Handler to handle when entities related to config entries updating disabled_by."""
 
-    RELOAD_AFTER_UPDATE_DELAY = 30
-
     def __init__(self, hass: HomeAssistant) -> None:
         """Initialize the handler."""
         self.hass = hass
@@ -1170,7 +1170,7 @@ class EntityRegistryDisabledHandler:
             self._remove_call_later()
 
         self._remove_call_later = self.hass.helpers.event.async_call_later(
-            self.RELOAD_AFTER_UPDATE_DELAY, self._handle_reload
+            RELOAD_AFTER_UPDATE_DELAY, self._handle_reload
         )
 
     async def _handle_reload(self, _now: Any) -> None:
diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py
index 84a646ed2ef..a506135c16d 100644
--- a/tests/components/config/test_entity_registry.py
+++ b/tests/components/config/test_entity_registry.py
@@ -7,7 +7,7 @@ from homeassistant.components.config import entity_registry
 from homeassistant.const import ATTR_ICON
 from homeassistant.helpers.entity_registry import RegistryEntry
 
-from tests.common import MockEntity, MockEntityPlatform, mock_registry
+from tests.common import MockConfigEntry, MockEntity, MockEntityPlatform, mock_registry
 
 
 @pytest.fixture
@@ -162,18 +162,20 @@ async def test_update_entity(hass, client):
     msg = await client.receive_json()
 
     assert msg["result"] == {
-        "config_entry_id": None,
-        "device_id": None,
-        "area_id": "mock-area-id",
-        "disabled_by": None,
-        "platform": "test_platform",
-        "entity_id": "test_domain.world",
-        "name": "after update",
-        "icon": "icon:after update",
-        "original_name": None,
-        "original_icon": None,
-        "capabilities": None,
-        "unique_id": "1234",
+        "entity_entry": {
+            "config_entry_id": None,
+            "device_id": None,
+            "area_id": "mock-area-id",
+            "disabled_by": None,
+            "platform": "test_platform",
+            "entity_id": "test_domain.world",
+            "name": "after update",
+            "icon": "icon:after update",
+            "original_name": None,
+            "original_icon": None,
+            "capabilities": None,
+            "unique_id": "1234",
+        }
     }
 
     state = hass.states.get("test_domain.world")
@@ -208,18 +210,75 @@ async def test_update_entity(hass, client):
     msg = await client.receive_json()
 
     assert msg["result"] == {
-        "config_entry_id": None,
-        "device_id": None,
-        "area_id": "mock-area-id",
-        "disabled_by": None,
-        "platform": "test_platform",
-        "entity_id": "test_domain.world",
-        "name": "after update",
-        "icon": "icon:after update",
-        "original_name": None,
-        "original_icon": None,
-        "capabilities": None,
-        "unique_id": "1234",
+        "entity_entry": {
+            "config_entry_id": None,
+            "device_id": None,
+            "area_id": "mock-area-id",
+            "disabled_by": None,
+            "platform": "test_platform",
+            "entity_id": "test_domain.world",
+            "name": "after update",
+            "icon": "icon:after update",
+            "original_name": None,
+            "original_icon": None,
+            "capabilities": None,
+            "unique_id": "1234",
+        },
+        "reload_delay": 30,
+    }
+
+
+async def test_update_entity_require_restart(hass, client):
+    """Test updating entity."""
+    config_entry = MockConfigEntry(domain="test_platform")
+    config_entry.add_to_hass(hass)
+    mock_registry(
+        hass,
+        {
+            "test_domain.world": RegistryEntry(
+                config_entry_id=config_entry.entry_id,
+                entity_id="test_domain.world",
+                unique_id="1234",
+                # Using component.async_add_entities is equal to platform "domain"
+                platform="test_platform",
+            )
+        },
+    )
+    platform = MockEntityPlatform(hass)
+    entity = MockEntity(unique_id="1234")
+    await platform.async_add_entities([entity])
+
+    state = hass.states.get("test_domain.world")
+    assert state is not None
+
+    # UPDATE DISABLED_BY TO NONE
+    await client.send_json(
+        {
+            "id": 8,
+            "type": "config/entity_registry/update",
+            "entity_id": "test_domain.world",
+            "disabled_by": None,
+        }
+    )
+
+    msg = await client.receive_json()
+
+    assert msg["result"] == {
+        "entity_entry": {
+            "config_entry_id": config_entry.entry_id,
+            "device_id": None,
+            "area_id": None,
+            "disabled_by": None,
+            "platform": "test_platform",
+            "entity_id": "test_domain.world",
+            "name": None,
+            "icon": None,
+            "original_name": None,
+            "original_icon": None,
+            "capabilities": None,
+            "unique_id": "1234",
+        },
+        "require_restart": True,
     }
 
 
@@ -257,18 +316,20 @@ async def test_update_entity_no_changes(hass, client):
     msg = await client.receive_json()
 
     assert msg["result"] == {
-        "config_entry_id": None,
-        "device_id": None,
-        "area_id": None,
-        "disabled_by": None,
-        "platform": "test_platform",
-        "entity_id": "test_domain.world",
-        "name": "name of entity",
-        "icon": None,
-        "original_name": None,
-        "original_icon": None,
-        "capabilities": None,
-        "unique_id": "1234",
+        "entity_entry": {
+            "config_entry_id": None,
+            "device_id": None,
+            "area_id": None,
+            "disabled_by": None,
+            "platform": "test_platform",
+            "entity_id": "test_domain.world",
+            "name": "name of entity",
+            "icon": None,
+            "original_name": None,
+            "original_icon": None,
+            "capabilities": None,
+            "unique_id": "1234",
+        }
     }
 
     state = hass.states.get("test_domain.world")
@@ -335,24 +396,94 @@ async def test_update_entity_id(hass, client):
     msg = await client.receive_json()
 
     assert msg["result"] == {
-        "config_entry_id": None,
-        "device_id": None,
-        "area_id": None,
-        "disabled_by": None,
-        "platform": "test_platform",
-        "entity_id": "test_domain.planet",
-        "name": None,
-        "icon": None,
-        "original_name": None,
-        "original_icon": None,
-        "capabilities": None,
-        "unique_id": "1234",
+        "entity_entry": {
+            "config_entry_id": None,
+            "device_id": None,
+            "area_id": None,
+            "disabled_by": None,
+            "platform": "test_platform",
+            "entity_id": "test_domain.planet",
+            "name": None,
+            "icon": None,
+            "original_name": None,
+            "original_icon": None,
+            "capabilities": None,
+            "unique_id": "1234",
+        }
     }
 
     assert hass.states.get("test_domain.world") is None
     assert hass.states.get("test_domain.planet") is not None
 
 
+async def test_update_existing_entity_id(hass, client):
+    """Test update entity id to an already registered entity id."""
+    mock_registry(
+        hass,
+        {
+            "test_domain.world": RegistryEntry(
+                entity_id="test_domain.world",
+                unique_id="1234",
+                # Using component.async_add_entities is equal to platform "domain"
+                platform="test_platform",
+            ),
+            "test_domain.planet": RegistryEntry(
+                entity_id="test_domain.planet",
+                unique_id="2345",
+                # Using component.async_add_entities is equal to platform "domain"
+                platform="test_platform",
+            ),
+        },
+    )
+    platform = MockEntityPlatform(hass)
+    entities = [MockEntity(unique_id="1234"), MockEntity(unique_id="2345")]
+    await platform.async_add_entities(entities)
+
+    await client.send_json(
+        {
+            "id": 6,
+            "type": "config/entity_registry/update",
+            "entity_id": "test_domain.world",
+            "new_entity_id": "test_domain.planet",
+        }
+    )
+
+    msg = await client.receive_json()
+
+    assert not msg["success"]
+
+
+async def test_update_invalid_entity_id(hass, client):
+    """Test update entity id to an invalid entity id."""
+    mock_registry(
+        hass,
+        {
+            "test_domain.world": RegistryEntry(
+                entity_id="test_domain.world",
+                unique_id="1234",
+                # Using component.async_add_entities is equal to platform "domain"
+                platform="test_platform",
+            )
+        },
+    )
+    platform = MockEntityPlatform(hass)
+    entities = [MockEntity(unique_id="1234"), MockEntity(unique_id="2345")]
+    await platform.async_add_entities(entities)
+
+    await client.send_json(
+        {
+            "id": 6,
+            "type": "config/entity_registry/update",
+            "entity_id": "test_domain.world",
+            "new_entity_id": "another_domain.planet",
+        }
+    )
+
+    msg = await client.receive_json()
+
+    assert not msg["success"]
+
+
 async def test_remove_entity(hass, client):
     """Test removing entity."""
     registry = mock_registry(
@@ -380,3 +511,20 @@ async def test_remove_entity(hass, client):
 
     assert msg["success"]
     assert len(registry.entities) == 0
+
+
+async def test_remove_non_existing_entity(hass, client):
+    """Test removing non existing entity."""
+    mock_registry(hass, {})
+
+    await client.send_json(
+        {
+            "id": 6,
+            "type": "config/entity_registry/remove",
+            "entity_id": "test_domain.world",
+        }
+    )
+
+    msg = await client.receive_json()
+
+    assert not msg["success"]
diff --git a/tests/components/tasmota/test_sensor.py b/tests/components/tasmota/test_sensor.py
index b8e5583579b..8da08b16376 100644
--- a/tests/components/tasmota/test_sensor.py
+++ b/tests/components/tasmota/test_sensor.py
@@ -412,11 +412,7 @@ async def test_enable_status_sensor(hass, mqtt_mock, setup_tasmota):
 
     async_fire_time_changed(
         hass,
-        dt.utcnow()
-        + timedelta(
-            seconds=config_entries.EntityRegistryDisabledHandler.RELOAD_AFTER_UPDATE_DELAY
-            + 1
-        ),
+        dt.utcnow() + timedelta(seconds=config_entries.RELOAD_AFTER_UPDATE_DELAY + 1),
     )
     await hass.async_block_till_done()
 
diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py
index 5d63964d6b1..59e1b0754c0 100644
--- a/tests/test_config_entries.py
+++ b/tests/test_config_entries.py
@@ -1189,11 +1189,7 @@ async def test_reload_entry_entity_registry_works(hass):
 
     async_fire_time_changed(
         hass,
-        dt.utcnow()
-        + timedelta(
-            seconds=config_entries.EntityRegistryDisabledHandler.RELOAD_AFTER_UPDATE_DELAY
-            + 1
-        ),
+        dt.utcnow() + timedelta(seconds=config_entries.RELOAD_AFTER_UPDATE_DELAY + 1),
     )
     await hass.async_block_till_done()
 
-- 
GitLab